Coverage for src/flag_gems/ops/argmin.py: 45%

161 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flattened argmin: reduce one chunk per program.

    Program ``pid`` covers the flat positions
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)`` of ``inp`` and writes the
    chunk minimum to ``mid_value[pid]`` and the flat position of that
    minimum to ``mid_index[pid]``.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    positions = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < M

    # Lanes past the end of the input are filled with get_dtype_max(...)
    # (presumably the largest representable value, so they never win the
    # reduction -- confirm against flag_gems.utils.limits).
    pad_value = get_dtype_max(inp.type.element_ty)
    values = tl.load(inp + positions, mask=in_bounds, other=pad_value)

    chunk_min, chunk_argmin = tl.min(values, axis=0, return_indices=True)

    # tl.min reports a position inside the chunk; shift it to a flat index.
    tl.store(mid_value + block_id, chunk_min)
    tl.store(mid_index + block_id, chunk_argmin + block_start)

39 

40 

@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Stage 2 of the flattened argmin: combine the per-chunk results.

    Finds the smallest of the first ``mid_size`` entries of ``mid_value``
    and copies the matching entry of ``mid_index`` to ``out``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size

    # Padding lanes are filled with get_dtype_max(...) so that only the real
    # partial minima are expected to be selected.
    pad_value = get_dtype_max(mid_value.type.element_ty)
    partial_minima = tl.load(mid_value + lanes, mask=valid, other=pad_value)

    winner = tl.argmin(partial_minima, axis=0)
    tl.store(out, tl.load(mid_index + winner))

59 

60 

def heur_block_n(args):
    """Return the next power of two of ``args["N"]``, capped at 4096."""
    n_pow2 = triton.next_power_of_2(args["N"])
    return n_pow2 if n_pow2 < 4096 else 4096

63 

64 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel_opt_k1(
    inp,
    out_index,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise argmin for the K == 1 case.

    ``inp`` is addressed as a row-major ``(M, N)`` array. Each program
    handles ``BLOCK_M`` rows, scans them in tiles of ``BLOCK_N`` columns and
    writes one int64 column index per row to ``out_index``.

    Loads and the final store are bounds-masked, so ``M`` and ``N`` do not
    have to be multiples of ``BLOCK_M`` / ``BLOCK_N``.
    """
    pid_m = tle.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    m_mask = m_offset < M

    dtype = inp.type.element_ty
    # bfloat16 running minima are kept in float32, the type tl.min returns
    # for bfloat16 input.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_val = get_dtype_max(dtype)

    min_vals = tl.full([BLOCK_M], dtype=acc_type, value=max_val)
    argmin_vals = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        # Without this mask a tile that overhangs the row would read the
        # start of the next row (or memory past the tensor) and could report
        # an index >= N. Masked lanes get max_val so they cannot beat a real
        # element; ties resolve to the left, i.e. to the real element.
        mask = m_mask[:, None] & (n_offset[None, :] < N)
        inp_vals = tl.load(inp + offset, mask=mask, other=max_val)

        local_min, local_argmin = tl.min(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict "<" keeps the earliest tile on ties, so the first minimum
        # along the row wins.
        update = local_min < min_vals
        min_vals = tl.where(update, local_min, min_vals)
        argmin_vals = tl.where(update, start_n + local_argmin, argmin_vals)

    out_ptr = out_index + m_offset
    # Rows >= M belong to no output element and must not be written.
    tl.store(out_ptr, argmin_vals, mask=m_mask)

99 

100 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("argmin_split_k"), key=["M", "N", "K"]
)
@triton.jit
def argmin_split_K_kernel_merged(
    inp,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Argmin over the middle axis of a row-major ``(M, N, K)`` array.

    Each program owns a ``(BLOCK_M, BLOCK_K)`` tile of the ``(M, K)`` output
    and walks the N axis in steps of ``BLOCK_N``, keeping a running minimum
    and its index. The result is stored to ``out_index`` at ``m * K + k``.
    """
    tile_m = tle.program_id(0)
    tile_k = tle.program_id(1)

    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1)
    cols = tile_k * BLOCK_K + tl.arange(0, BLOCK_K)[None, :]  # (1, BLOCK_K)

    row_ok = rows < M
    col_ok = cols < K
    tile_ok = row_ok & col_ok

    # bfloat16 input is reduced in float32; every other dtype is used as is.
    compute_dtype = tl.float32 if dtype == tl.bfloat16 else dtype
    max_val = get_dtype_max(compute_dtype)

    best_val = tl.full((BLOCK_M, BLOCK_K), max_val, dtype=compute_dtype)
    best_idx = tl.full((BLOCK_M, BLOCK_K), 0, dtype=tl.int64)

    for start_n in range(0, N, BLOCK_N):
        n_idx = start_n + tl.arange(0, BLOCK_N)
        n_ok = n_idx < N

        # Broadcasts to a (BLOCK_N, BLOCK_M, BLOCK_K) tile of element offsets.
        tile_offsets = rows * N * K + n_idx[:, None, None] * K + cols[None, :, :]
        tile_mask = row_ok & n_ok[:, None, None] & col_ok[None, :, :]

        tile = tl.load(inp + tile_offsets, mask=tile_mask, other=max_val)
        tile = tile.to(compute_dtype)

        # Reduce along the leading (N) axis of the tile.
        tile_min, tile_argmin = tl.min(
            tile, 0, return_indices=True, return_indices_tie_break_left=True
        )
        tile_argmin += start_n

        # Strict "<" keeps the earlier N position when values tie.
        improved = tile_min < best_val
        best_val = tl.where(improved, tile_min, best_val)
        best_idx = tl.where(improved, tile_argmin, best_idx)

    out_offsets = rows * K + cols  # (BLOCK_M, BLOCK_K)
    tl.store(out_index + out_offsets, best_idx, mask=tile_ok)

157 

158 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """General argmin over the middle axis of a row-major ``(M, N, K)`` array.

    Grid axis 0 selects a block of ``BLOCK_M`` rows and grid axis 1 selects a
    single ``k``. The N axis is scanned in tiles of ``BLOCK_N`` and the index
    of the first minimum is stored to ``out_index`` at ``m * K + k``.
    """
    row_block = tle.program_id(0)
    k_idx = tle.program_id(1)
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < M

    elem_ty = inp.type.element_ty
    # tl.min promotes bfloat16 to float32, so the accumulator uses float32.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    pad_value = get_dtype_max(elem_ty)

    best_val = tl.full([BLOCK_M], dtype=acc_ty, value=pad_value)
    best_idx = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        tile_offsets = rows[:, None] * N * K + cols[None, :] * K + k_idx
        in_bounds = row_ok[:, None] & (cols[None, :] < N)
        tile = tl.load(inp + tile_offsets, mask=in_bounds, other=pad_value)

        # If return_indices is ever unsupported, a separate tl.argmin(tile, 1)
        # call can supply the indices instead.
        tile_min, tile_argmin = tl.min(
            tile, 1, return_indices=True, return_indices_tie_break_left=True
        )

        # Strict "<" keeps the earliest tile on ties.
        improved = tile_min < best_val
        best_val = tl.where(improved, tile_min, best_val)
        best_idx = tl.where(improved, start_n + tile_argmin, best_idx)

    tl.store(out_index + rows * K + k_idx, best_idx, mask=row_ok)

201 

202 

def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp``.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce. ``None`` reduces over the flattened tensor.
        keepdim: If true, reduced dimensions are kept with size 1.
        dtype: dtype of the intermediate buffer of per-block minima used by
            the ``dim=None`` path. Defaults to ``inp.dtype``.

    Returns:
        An ``int64`` tensor of indices allocated on ``inp.device``.

    Raises:
        IndexError: If the reduction would run over zero elements.
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ARGMIN")
    if dim is None:
        M = inp.numel()
        if M == 0:
            # Previously surfaced as a ZeroDivisionError from cdiv(0, 0).
            raise IndexError(
                "argmin(): Expected reduction dim to be specified for "
                "input.numel() == 0."
            )
        if dtype is None:
            dtype = inp.dtype
        # The stage-1 kernel walks memory linearly, so its flat index only
        # equals the logical flat index when the storage is contiguous.
        inp = inp.contiguous()

        # Two-stage reduction: ~sqrt(M) chunks of ~sqrt(M) elements each.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
            )
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmin(): Expected reduction dim {dim} to have non-zero size."
        )
    # View the input as (M, N, K) with N the reduced axis. K is taken from
    # the trailing shape so that an empty leading dimension (M == 0) does not
    # cause a division by zero.
    M = math.prod(shape[:dim])
    K = math.prod(shape[dim + 1 :])
    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)
    if M == 0 or K == 0:
        # Empty output: nothing to compute, and a zero-sized grid must not
        # be launched.
        return out_index

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    # dtypes the split-K kernel can be specialised for.
    torch2triton_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }

    with torch_device_fn.device(inp.device):
        if K == 1 and inp.dtype not in (torch.int32, torch.int16):
            argmin_kernel_opt_k1[grid](
                inp,
                out_index,
                M,
                N,
            )
        elif (
            N % 64 == 0
            and K % 32 == 0
            and M % 8 == 0
            # Any dtype outside the table (int16/int32, but also float64,
            # int64, ...) falls through to the general kernel instead of
            # raising KeyError.
            and inp.dtype in torch2triton_dtype
        ):
            # Size the grid from the block sizes the autotuner actually
            # picked rather than assuming BLOCK_M == 8 and BLOCK_K == 32.
            grid_for_split_K = lambda meta: (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(K, meta["BLOCK_K"]),
            )
            argmin_split_K_kernel_merged[grid_for_split_K](
                inp,
                out_index,
                M,
                N,
                K,
                dtype=torch2triton_dtype[inp.dtype],
            )
        else:
            # General support for any other (N, K) and dtype.
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )

    return out_index