Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/argmax.py: 0%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils.shape_utils import can_use_int32_index 

12 

# Module logger, namespaced beneath the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def cfggen_reduce_op():
    """Fetch the tuned autotune config list for the flat-reduction kernel."""
    configs = runtime.get_tuned_config("argmax_kernel_1")
    return configs

18 

19 

@libentry()
@triton.jit
def argmax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    """Single-program argmax: the whole flat input (M elements) is reduced at once."""
    idx = tl.arange(0, M)
    values = tl.load(inp + idx)
    winner = tl.argmax(values, axis=0)
    tl.store(out, winner.to(tl.int64))

31 

32 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
)
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    real_size,
    M,
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    # Stage 1 of a two-stage flat argmax: each program scans one contiguous
    # chunk of the flattened input and writes its local maximum value and
    # its global index into mid_value[pid] / mid_index[pid].  Stage 2
    # (argmax_kernel_2) reduces those partial results.
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(axis=0)

    # Split the M elements evenly across programs (ceiling division).
    size_per_job = (M + num_jobs - 1) // num_jobs
    start_idx = pid * size_per_job
    end_idx = min(start_idx + size_per_job, M)

    # Running best value/index for this program's chunk.
    max_tmp = -float("inf")
    index_tmp = 0
    if INT64_INDEX:
        index_tmp = index_tmp.to(tl.int64)
    for off in range(start_idx, end_idx, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < end_idx
        # Out-of-range lanes load -inf so they can never win the max.
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
        if max_val > max_tmp:
            max_tmp = max_val.to(tl.float32)
            index_tmp = max_index + off  # block-local index -> global index
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_tmp)
    tl.store(max_index_ptr, index_tmp)
    # Record how many programs actually ran so stage 2 can mask unused slots.
    tl.store(real_size, num_jobs)

75 

76 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, real_size, mid_size: tl.constexpr):
    """Stage 2: reduce the per-program partial maxima to one global index."""
    valid = tl.load(real_size)
    lane = tl.arange(0, mid_size)
    # Slots beyond the number of launched programs read -inf and cannot win.
    partial = tl.load(mid_value + lane, mask=lane < valid, other=-float("inf"))
    best = tl.argmax(partial, axis=0)
    winner_index = tl.load(mid_index + best)
    tl.store(out, winner_index)

88 

89 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("argmax"),
    key=[
        "M",
        "N",
    ],
    strategy=["log", "log"],
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    # Argmax along the middle axis of an (M, N, K) view of a contiguous
    # input: each program covers BLOCK_M rows of one k-slice and scans the
    # N axis in BLOCK_N-wide tiles, carrying a running max/argmax per row.
    # set offset
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    if INT64_INDEX:
        pid_m = pid_m.to(tl.int64)
        pid_k = pid_k.to(tl.int64)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Linear offset for an (M, N, K) contiguous layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Masked-out elements read -inf so they never win the reduction.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        # Strict ">" keeps the earlier tile's index on ties across tiles,
        # matching the leftmost tie-break within a tile.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)

139 

140 

def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``.

    Mirrors ``torch.argmax``: with ``dim=None`` the flattened input is
    reduced to a single int64 index (shape ``[]``, or all-ones shape when
    ``keepdim``); otherwise the reduction runs along ``dim``.  ``dtype`` is
    accepted for signature compatibility; it is assigned in the
    ``dim=None`` branch but never read afterwards.
    """
    logger.debug("GEMS_TSINGMICRO ARGMAX")
    if dim is None:
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype

        # NOTE(review): this branch indexes inp linearly without calling
        # .contiguous() first (unlike the dim branch below) — presumably
        # callers pass contiguous tensors; verify for strided inputs.
        use_int64_index = not can_use_int32_index(inp)

        if keepdim:
            # keepdim on a full reduction yields an all-ones shape.
            shape = list(inp.shape)
            for i in range(0, inp.dim()):
                shape[i] = 1
            out = torch.empty(shape, dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        if M <= 65530:
            # Small input: one program reduces everything in a single pass.
            with torch_device_fn.device(inp.device):
                argmax_kernel_once[(1, 1, 1)](inp, out, M)
        else:
            # Two-stage reduction: per-program partial argmax over chunks,
            # then a single-program reduction over the partial results.
            # Grid is capped at the device's multiprocessor count.
            grid = lambda meta: (
                min(
                    triton.cdiv(M, meta["BLOCK_SIZE"]),
                    torch_device_fn.get_device_properties().multi_processor_count,
                ),
            )
            mid_size = torch_device_fn.get_device_properties().multi_processor_count
            # real_size receives the number of programs actually launched
            # (written by argmax_kernel_1, read by argmax_kernel_2).
            real_size = torch.empty([], dtype=torch.int32, device=inp.device)
            mid_value = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
            mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
            with torch_device_fn.device(inp.device):
                argmax_kernel_1[grid](
                    inp,
                    mid_value,
                    mid_index,
                    real_size,
                    M,
                    INT64_INDEX=use_int64_index,
                )
                argmax_kernel_2[(1, 1, 1)](
                    mid_value, mid_index, out, real_size, mid_size
                )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            # Empty input: return an all-zero index tensor of the reduced
            # shape without launching a kernel.
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the input as (M, N, K): N is the reduced axis, M the product
        # of leading dims, K the product of trailing dims.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()
        use_int64_index = not can_use_int32_index(inp)

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        # One program per (BLOCK_M rows, k-slice) pair.
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmax_kernel[grid](inp, out_index, M, N, K, INT64_INDEX=use_int64_index)

        return out_index