Coverage for src/flag_gems/utils/tensor_wrapper.py: 69%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import math 

2 

3import torch 

4 

5import flag_gems 

6 

7 

8class TypedPtr: 

9 """This is a minimal requirement for a type to be treated as a tensor in triton jit 

10 function. Basically it is a typed pointer, without knowning the device, size, shape, 

11 strides, etc. 

12 """ 

13 

14 def __init__(self, ptr: int, dtype: torch.dtype): 

15 self.ptr = ptr 

16 self.dtype = dtype 

17 

18 def data_ptr(self) -> int: 

19 return self.ptr 

20 

21 def untyped_storage(self): 

22 return self 

23 

24 @classmethod 

25 def from_tensor(cls, tensor: torch.Tensor, offset: int = 0): 

26 return cls(tensor.data_ptr() + tensor.element_size() * offset, tensor.dtype) 

27 

28 @classmethod 

29 def reinterpret_tensor(cls, tensor: torch.Tensor, dtype: torch.dtype, offset=0): 

30 return cls(tensor.data_ptr() + dtype.itemsize * offset, dtype) 

31 

32 

33class StridedBuffer: 

34 """A drop-in replacement of torch.Tensor that can be used in wrapper generated by 

35 PointwiseDynamicFunction. It allows us to use a different shape, stride, data 

36 pointer that that of the base tensor. 

37 

38 It is a kind of reinterpretation of the base tensor. We make this class since we 

39 cannot get a Tensor view with negative strides via torch APIs, while we need this 

40 to implement flip op. 

41 

42 Although generated code can accept torch.Tensor & StridedBuffer, but StridedBuffer 

43 may not have all the methods as torch.Tensors do. We add some attributes & methods 

44 with the same name as torch.Tensor, which are used in the generated code. But we 

45 may not cover all the methods, add one if what you need is missing here. 

46 

47 And can also be used in triton kernels since it also has dtype & data_ptr(). 

48 """ 

49 

50 def __init__( 

51 self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0 

52 ): 

53 self._base = base 

54 self.dtype = dtype or base.dtype 

55 self.offset = offset 

56 

57 if offset == 0: 

58 self._data_ptr = self._base.data_ptr() 

59 else: 

60 # TODO[kunlunxin]: we will upgrade torch version in 2025.04 

61 if flag_gems.vendor_name == "kunlunxin": 

62 

63 def get_dtype_bytes(dtype): 

64 if dtype.is_floating_point: 

65 return int(torch.finfo(dtype).bits / 8) 

66 else: 

67 return int(torch.iinfo(dtype).bits / 8) 

68 

69 offset = get_dtype_bytes(self.dtype) * offset 

70 else: 

71 offset = self.dtype.itemsize * offset 

72 

73 self._data_ptr = self._base.data_ptr() + offset 

74 self.shape = tuple(shape if shape is not None else self._base.shape) 

75 self._strides = tuple(strides if strides is not None else self._base.stride()) 

76 self.device = self._base.device 

77 self.ndim = len(self.shape) 

78 

79 def stride(self): 

80 return self._strides 

81 

82 def size(self): 

83 return self.shape 

84 

85 def element_size(self): 

86 return self.dtype.itemsize 

87 

88 def numel(self): 

89 return math.prod(self.shape) 

90 

91 def dim(self): 

92 return self.ndim 

93 

94 def unwrap(self): 

95 return self._base 

96 

97 def data_ptr(self): 

98 return self._data_ptr 

99 

100 def untyped_storage(self): 

101 return self._base.untyped_storage() 

102 

103 def clone(self): 

104 return StridedBuffer( 

105 self._base.clone(), 

106 shape=self.shape, 

107 strides=self._strides, 

108 dtype=self.dtype, 

109 offset=self.offset, 

110 ) 

111 

112 def copy_(self, src): 

113 if isinstance(src, StridedBuffer): 

114 self._base.copy_(src._base) 

115 self._strides = src._strides 

116 self.shape = src.shape 

117 self.dtype = src.dtype 

118 self.device = src.device 

119 self.offset = src.offset 

120 else: 

121 src_buffer = StridedBuffer(src) 

122 self.copy_(src_buffer) 

123 return self