Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/argmax.py: 0%
(137 statements; report generated by coverage.py v7.6.9 at 2026-03-24 15:40 +0800)

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

# Map a torch dtype to (Triton dtype, minimum representable value).
# The minimum is used as the padding/identity element for argmax loads.
# NOTE(review): bfloat16 is deliberately mapped to tl.float32 and the
# float32 minimum — presumably because bf16 maxima are tracked in fp32
# inside the kernels; confirm against the kunlunxin backend's bf16 support.
torch_dtype_to_tl_dtype_and_min_value = {
    torch.int16: (tl.int16, torch.iinfo(torch.int16).min),
    torch.int32: (tl.int32, torch.iinfo(torch.int32).min),
    torch.float16: (tl.float16, torch.finfo(torch.float16).min),
    torch.float32: (tl.float32, torch.finfo(torch.float32).min),
    torch.bfloat16: (tl.float32, torch.finfo(torch.float32).min),
}
# Module-level logger following the project convention.
logger = logging.getLogger(__name__)

22 

23 

@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the flat argmax: each program scans one BLOCK_SIZE
    slice of the flattened input and records the slice's maximum and its
    global flat index into the mid buffers."""
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes load the dtype minimum so they can never win.
    pad = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offsets, mask=in_bounds, other=pad)
    block_max, local_idx = tl.max(vals, axis=0, return_indices=True)
    # Convert the slice-local winner into a global flat index and emit both.
    tl.store(mid_value + block_id, block_max)
    tl.store(mid_index + block_id, local_idx + block_id * BLOCK_SIZE)

45 

46 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the flat argmax: pick the winning slice among the
    per-block maxima and emit that slice's recorded global index."""
    lane = tl.arange(0, BLOCK_MID)
    valid = lane < mid_size
    # Pad inactive lanes with the dtype minimum so they cannot win.
    pad = get_dtype_min(mid_value.type.element_ty)
    partial_max = tl.load(mid_value + lane, mask=valid, other=pad)
    winner = tl.argmax(partial_max, axis=0)
    # The overall max's flat index was stored at the winner's slot.
    tl.store(out, tl.load(mid_index + winner))

59 

60 

def heur_m_block_size(args):
    """Heuristic BLOCK_M: spread the M rows across the 12 hardware
    clusters, rounded up to the next power of two."""
    cluster_num = 12
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(rows_per_cluster)

63 

64 

def heur_n_block_size(args):
    """Heuristic BLOCK_N: the next power of two covering N, capped at 8192.

    The previous version routed through ``import builtins`` /
    ``builtins.min``; nothing in this module shadows ``min``, so the
    plain built-in is equivalent and clearer.
    """
    return min(triton.next_power_of_2(args["N"]), 8192)

69 

70 

@libentry()
# @triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax along the middle axis of an (M, N, K) view of a contiguous
    tensor: each program handles BLOCK_M rows of one k slice, tiling over
    N in BLOCK_N chunks, and writes the winning n-index per row."""
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # Track bf16 running maxima in fp32; other dtypes compare natively.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    min_value = get_dtype_min(dtype)
    # Running per-row maximum and its argmax, seeded with dtype min / index 0.
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the contiguous (M, N, K) layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Masked lanes read the dtype minimum so they cannot win the max.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        # Strict '>' keeps the earlier tile's index on cross-tile ties,
        # which together with the left tie-break above yields the
        # first-occurrence index.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    # Scatter each row's winning n-index back to its (m, k) slot.
    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)

118 

119 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel_small_n(
    inp,
    out_index,
    M,
    N,
    K,
    tl_dtype: tl.constexpr,
    dtype_min_value: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Single-tile argmax along the middle axis of an (M, N, K) view:
    the whole N extent fits in one BLOCK_N tile, so no accumulation loop
    is needed. The wrapper only dispatches here when N == 1."""
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # NOTE(review): tl_dtype is widened here but never used afterwards —
    # presumably a leftover from a variant that allocated an accumulator
    # of that type; confirm before removing.
    if tl_dtype is tl.int16:
        tl_dtype = tl.int32
    n_offset = tl.arange(0, BLOCK_N)
    # Flat offset into the contiguous (M, N, K) layout.
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    offset_index = m_offset * K + pid_k
    # set mask
    mask1 = m_offset < M
    mask = m_offset[:, None] < M and n_offset[None, :] < N
    inp_ptrs = inp + offset
    # Masked lanes read the dtype minimum so they cannot win the max.
    inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_min_value)
    _, result_index = tl.max(inp_vals, axis=1, return_indices=True)

    out_index_ptrs = out_index + offset_index

    tl.store(out_index_ptrs, result_index, mask=mask1)

154 

155 

def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``.

    With ``dim=None`` the tensor is reduced as a flat array via a
    two-stage kernel; otherwise the reduction runs along ``dim`` over an
    (M, N, K) view of the contiguous input. Matches the interface of
    ``torch.argmax`` plus an extra ``dtype`` keyword for the mid buffer.
    """
    logger.debug("GEMS ARGMAX")
    if dim is None:
        total = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # Stage sizes: ~sqrt(total) elements per block, one partial per block.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(total)))
        mid_size = triton.cdiv(total, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # Every dimension collapses to 1 when reducing the whole tensor.
            out = torch.empty([1] * inp.dim(), dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmax_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                total,
                block_size,
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Empty input: return zeros of the reduced shape without launching.
        out_shape = list(shape)
        if keepdim:
            out_shape[dim] = 1
        else:
            del out_shape[dim]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    # Collapse the tensor to an (M, N, K) view with N being the reduced axis.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    reduced_shape = list(shape)
    reduced_shape[dim] = 1
    out_index = torch.empty(reduced_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    def grid(meta):
        # One program per BLOCK_M rows, one per k slice.
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    if N == 1:
        tl_dtype, dtype_min_value = torch_dtype_to_tl_dtype_and_min_value[inp.dtype]
        with torch_device_fn.device(inp.device):
            argmax_kernel_small_n[grid](
                inp,
                out_index,
                M,
                N,
                K,
                tl_dtype,
                dtype_min_value,
            )
        return out_index

    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
            is_use_mask_zero=True,
        )

    return out_index