Coverage for src/flag_gems/ops/argmax.py: 44%

163 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flattened argmax: each program scans one BLOCK_SIZE
    slice of the flat input and records the slice's max value and the flat
    index where it occurs into mid_value[pid] / mid_index[pid]."""
    block_id = tle.program_id(0)
    lanes = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M
    # Pad out-of-range lanes with the dtype minimum so they can never win.
    pad = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + lanes, mask=in_bounds, other=pad)
    blk_max, blk_arg = tl.max(vals, axis=0, return_indices=True)
    # Translate the block-local winner position into a flat index.
    tl.store(mid_value + block_id, blk_max)
    tl.store(mid_index + block_id, blk_arg + block_id * BLOCK_SIZE)

38 

39 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flattened argmax: reduce the per-block partial maxima
    from stage 1 to a single winner and emit its flat index."""
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    # Invalid lanes read the dtype minimum so they never become the winner.
    pad = get_dtype_min(mid_value.type.element_ty)
    partials = tl.load(mid_value + lanes, mask=valid, other=pad)
    winner = tl.argmax(partials, axis=0)
    # The winning block's stored flat index is the final answer.
    tl.store(out, tl.load(mid_index + winner))

52 

53 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_non_inner"))
@triton.jit
def argmax_kernel_non_inner(
    inp,
    out_index,
    M,
    N,
    K,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Argmax along a non-innermost axis of a logically (M, N, K)-shaped
    # contiguous input: reduces over N independently for each (m, k) pair.
    # Launched on a (M, cdiv(K, TILE_K)) grid — presumably set up by the
    # host wrapper; TILE_K/TILE_N/ONE_TILE_PER_CTA come from heuristics.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)

    # Compare half-precision inputs in fp32; other dtypes compare natively.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Identity element for max: masked-out loads read this and never win.
    min_value = get_dtype_min(cdtype)

    if ONE_TILE_PER_CTA:
        # The whole N extent fits in one TILE_N tile: one load, one reduce.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset
        mask = k_offset < K and n_offset[:, None] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # Reduce over axis 0 (the N tile); left tie-break matches torch.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, local_argmax, mask=mask1)
    else:
        # N is tiled: keep a running (max, argmax) per k-lane across tiles.
        max_values = tl.full([TILE_K], dtype=cdtype, value=min_value)
        argmax_values = tl.full([TILE_K], dtype=tl.int64, value=0)

        for start_n in range(0, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset
            mask = k_offset < K and n_offset[:, None] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' keeps the earliest tile's winner on ties, which
            # composes with the within-tile left tie-break above.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, argmax_values, mask=mask1)

113 

114 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_inner"))
@triton.jit
def argmax_kernel_inner(
    inp,
    out_index,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Argmax along the innermost (contiguous) axis of a logically (M, N)
    # input: one program per row. TILE_N / ONE_TILE_PER_CTA come from the
    # "argmax_inner" heuristics config.
    pid_m = tle.program_id(0)

    dtype = inp.type.element_ty
    # Identity element for max: masked loads read this and never win.
    min_value = get_dtype_min(dtype)

    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: single load + reduce.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offset
        mask = n_offset < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # Left tie-break matches torch.argmax's first-occurrence semantics.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, local_argmax)
    else:
        # Row is tiled: scalar running (max, argmax) across full tiles,
        # then one masked pass over the tail remainder.
        max_values = min_value
        argmax_values = 0

        loop_time = N // TILE_N
        remainder = N % TILE_N
        for start_n in range(0, loop_time):
            n_offset = start_n * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            inp_ptrs = inp + offset
            # Full tiles need no mask: every lane is in bounds.
            inp_vals = tl.load(inp_ptrs)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' keeps the earliest tile's winner on ties.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, start_n * TILE_N + local_argmax, argmax_values
            )

        if remainder:
            n_offset = loop_time * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            mask = n_offset < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, loop_time * TILE_N + local_argmax, argmax_values
            )

        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, argmax_values)

179 

180 

def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``.

    With ``dim=None`` the tensor is reduced as a flat array via a two-stage
    kernel; otherwise a single kernel reduces along ``dim``, dispatched on
    whether ``dim`` is the innermost axis of the contiguous layout.
    """
    logger.debug("GEMS ARGMAX")
    if dim is None:
        numel = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # Two-stage reduction: ~sqrt(numel)-sized blocks balance both stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
        mid_size = triton.cdiv(numel, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        # keepdim keeps every axis at size 1; otherwise a 0-d scalar.
        result_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(result_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmax_kernel_1[(mid_size, 1, 1)](
                inp, mid_value, mid_index, numel, block_size
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Empty input: emit an empty int64 result of the reduced shape.
        empty_shape = list(shape)
        if keepdim:
            empty_shape[dim] = 1
        else:
            del empty_shape[dim]
        return torch.zeros(empty_shape, dtype=torch.int64, device=inp.device)

    # View the input as logically (M, N, K) with N the reduced extent.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    reduced_shape = list(shape)
    reduced_shape[dim] = 1
    out_index = torch.empty(reduced_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    with torch_device_fn.device(inp.device):
        if K > 1:
            # Non-innermost reduction: tile over the trailing K extent.
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]))
            argmax_kernel_non_inner[grid](inp, out_index, M, N, K)
        else:
            # Innermost reduction: one program per row.
            grid = lambda meta: (M, 1, 1)
            argmax_kernel_inner[grid](inp, out_index, M, N)
    return out_index