Coverage for src/flag_gems/runtime/backend/_mthreads/ops/arange.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.arange import arange_start as default_arange_start 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry 

13from flag_gems.utils import triton_lang_extension as tle 

14 

# Module logger, namespaced under the mthreads backend ops package.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Active runtime device descriptor; `.name` supplies the default device string.
device_ = runtime.device

# dtypes the Triton path accepts; `_use_triton` falls back to the default
# implementation for anything outside this set.
_SUPPORTED_DTYPES = {
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.int32,
    torch.int64,
}

# Autotune search space for the kernel launch; keyed on problem size and the
# int64-index flag (see the @triton.autotune decorator below).
_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=1),
]

33 

34 

@libentry()
@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["n_elements", "USE_INT64"])
@triton.jit(do_not_specialize=["start", "step"])
def arange_kernel(
    out_ptr,  # pointer to the output buffer
    start,  # scalar: first value of the sequence
    step,  # scalar: increment between consecutive values
    n_elements,  # total number of elements to write
    IS_FLOAT: tl.constexpr,  # compute values in float32 when true
    USE_INT64: tl.constexpr,  # use 64-bit offsets/values when true
    BLOCK_SIZE: tl.constexpr,  # elements handled per program (autotuned)
):
    # Each program writes one contiguous BLOCK_SIZE chunk:
    # out[i] = start + i * step for i in [pid*BLOCK_SIZE, (pid+1)*BLOCK_SIZE).
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Widen (or pin) index arithmetic to the requested width; n_elements is
    # rebound to the same width so the mask comparison does not promote.
    if USE_INT64:
        offsets = offsets.to(tl.int64)
        n_elements = tl.full((1,), n_elements, tl.int64)
    else:
        offsets = offsets.to(tl.int32)
        n_elements = tl.full((1,), n_elements, tl.int32)
    mask = offsets < n_elements

    if IS_FLOAT:
        # Float path: single fused multiply-add in float32.
        idx = offsets.to(tl.float32)
        step_val = tl.full((1,), step, tl.float32)
        start_val = tl.full((1,), start, tl.float32)
        values = tl.fma(idx, step_val, start_val)
    else:
        # Integer path: exact arithmetic at the output's integer width.
        value_dtype = tl.int64 if USE_INT64 else tl.int32
        idx = offsets.to(value_dtype)
        step_val = tl.full((1,), step, value_dtype)
        start_val = tl.full((1,), start, value_dtype)
        values = start_val + idx * step_val

    tl.store(out_ptr + offsets, values, mask=mask)

70 

71 

72def _normalize_scalar(value): 

73 if isinstance(value, torch.Tensor): 

74 return value.item() 

75 return value 

76 

77 

78def _compute_size(start, end, step, is_float_dtype: bool) -> int: 

79 if step == 0: 

80 raise ValueError("arange(): step must be non-zero.") 

81 if is_float_dtype: 

82 size = math.ceil((end - start) / step) 

83 else: 

84 sgn = (step > 0) - (step < 0) 

85 size = (end - start + step - sgn) // step 

86 return int(size) if size > 0 else 0 

87 

88 

89def _use_triton(dtype: torch.dtype, device: torch.device, size: int) -> bool: 

90 if device.type != "musa": 

91 return False 

92 if dtype not in _SUPPORTED_DTYPES: 

93 return False 

94 return size > 0 

95 

96 

def _launch_triton_kernel(
    out: torch.Tensor,
    start,
    step,
    size: int,
    *,
    is_float_dtype: bool,
    use_int64: bool,
):
    """Fill *out* with ``start + i * step`` via arange_kernel and return it."""

    def grid(meta):
        # One program per BLOCK_SIZE chunk of the output.
        return (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    # Launch on the tensor's own device so the kernel targets the right stream.
    with torch_device_fn.device(out.device):
        arange_kernel[grid](
            out, start, step, size, IS_FLOAT=is_float_dtype, USE_INT64=use_int64
        )
    return out

117 

118 

def arange_start(
    start,
    end,
    step=1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """torch.arange(start, end, step) for the mthreads (MUSA) backend.

    Dispatches to the Triton kernel when the device is musa, the dtype is
    supported, and the result is non-empty; otherwise delegates to the
    default flag_gems implementation.

    Args:
        start, end, step: scalars or 0-d tensors defining the sequence.
        dtype: output dtype; defaults to torch.int64 (matching torch.arange
            with integer arguments).
        layout: forwarded to the fallback only; the Triton path ignores it.
        device: target device; defaults to the runtime device.
        pin_memory: forwarded to the output allocation; defaults to False.

    Returns:
        A 1-D tensor of the computed length.

    Raises:
        ValueError: if *step* is zero (via _compute_size).
    """
    logger.debug("GEMS_MTHREADS ARANGE")
    start = _normalize_scalar(start)
    end = _normalize_scalar(end)
    step = _normalize_scalar(step)

    if dtype is None:
        dtype = torch.int64
    if pin_memory is None:
        pin_memory = False
    if device is None:
        device = torch.device(device_.name)
    else:
        device = torch.device(device)

    # Probe the dtype directly instead of allocating a throwaway tensor
    # (the original did torch.is_floating_point(torch.tensor(0, dtype=dtype))).
    is_float_dtype = dtype.is_floating_point
    use_int64 = dtype == torch.int64
    size = _compute_size(start, end, step, is_float_dtype)

    if not _use_triton(dtype, device, size):
        return default_arange_start(
            start,
            end,
            step,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
        )

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    return _launch_triton_kernel(
        result,
        start,
        step,
        size,
        is_float_dtype=is_float_dtype,
        use_int64=use_int64,
    )

167 

168 

def arange(
    end,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """torch.arange(end): delegate to arange_start with start=0, step=1."""
    return arange_start(
        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )