Coverage for src/flag_gems/utils/tensor_wrapper.py: 69%
65 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import math
3import torch
5import flag_gems
8class TypedPtr:
9 """This is a minimal requirement for a type to be treated as a tensor in triton jit
10 function. Basically it is a typed pointer, without knowning the device, size, shape,
11 strides, etc.
12 """
14 def __init__(self, ptr: int, dtype: torch.dtype):
15 self.ptr = ptr
16 self.dtype = dtype
18 def data_ptr(self) -> int:
19 return self.ptr
21 def untyped_storage(self):
22 return self
24 @classmethod
25 def from_tensor(cls, tensor: torch.Tensor, offset: int = 0):
26 return cls(tensor.data_ptr() + tensor.element_size() * offset, tensor.dtype)
28 @classmethod
29 def reinterpret_tensor(cls, tensor: torch.Tensor, dtype: torch.dtype, offset=0):
30 return cls(tensor.data_ptr() + dtype.itemsize * offset, dtype)
33class StridedBuffer:
34 """A drop-in replacement of torch.Tensor that can be used in wrapper generated by
35 PointwiseDynamicFunction. It allows us to use a different shape, stride, data
36 pointer that that of the base tensor.
38 It is a kind of reinterpretation of the base tensor. We make this class since we
39 cannot get a Tensor view with negative strides via torch APIs, while we need this
40 to implement flip op.
42 Although generated code can accept torch.Tensor & StridedBuffer, but StridedBuffer
43 may not have all the methods as torch.Tensors do. We add some attributes & methods
44 with the same name as torch.Tensor, which are used in the generated code. But we
45 may not cover all the methods, add one if what you need is missing here.
47 And can also be used in triton kernels since it also has dtype & data_ptr().
48 """
50 def __init__(
51 self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0
52 ):
53 self._base = base
54 self.dtype = dtype or base.dtype
55 self.offset = offset
57 if offset == 0:
58 self._data_ptr = self._base.data_ptr()
59 else:
60 # TODO[kunlunxin]: we will upgrade torch version in 2025.04
61 if flag_gems.vendor_name == "kunlunxin":
63 def get_dtype_bytes(dtype):
64 if dtype.is_floating_point:
65 return int(torch.finfo(dtype).bits / 8)
66 else:
67 return int(torch.iinfo(dtype).bits / 8)
69 offset = get_dtype_bytes(self.dtype) * offset
70 else:
71 offset = self.dtype.itemsize * offset
73 self._data_ptr = self._base.data_ptr() + offset
74 self.shape = tuple(shape if shape is not None else self._base.shape)
75 self._strides = tuple(strides if strides is not None else self._base.stride())
76 self.device = self._base.device
77 self.ndim = len(self.shape)
79 def stride(self):
80 return self._strides
82 def size(self):
83 return self.shape
85 def element_size(self):
86 return self.dtype.itemsize
88 def numel(self):
89 return math.prod(self.shape)
91 def dim(self):
92 return self.ndim
94 def unwrap(self):
95 return self._base
97 def data_ptr(self):
98 return self._data_ptr
100 def untyped_storage(self):
101 return self._base.untyped_storage()
103 def clone(self):
104 return StridedBuffer(
105 self._base.clone(),
106 shape=self.shape,
107 strides=self._strides,
108 dtype=self.dtype,
109 offset=self.offset,
110 )
112 def copy_(self, src):
113 if isinstance(src, StridedBuffer):
114 self._base.copy_(src._base)
115 self._strides = src._strides
116 self.shape = src.shape
117 self.dtype = src.dtype
118 self.device = src.device
119 self.offset = src.offset
120 else:
121 src_buffer = StridedBuffer(src)
122 self.copy_(src_buffer)
123 return self