Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmax.py: 0%

148 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

# Per-op logger, named after this module's leaf name under the Ascend ops namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

15 

16 

@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the full-tensor (dim=None) argmax.

    Program ``pid`` scans the BLOCK_SIZE consecutive elements of the flat
    input that start at ``pid * BLOCK_SIZE``. It writes the chunk's maximum
    to ``mid_value[pid]`` and that maximum's flat input index to
    ``mid_index[pid]``.

    Lanes at or beyond ``M`` are loaded as ``get_dtype_min`` of the input
    dtype (presumably the lowest representable value), so they should not
    be selected over valid lanes.
    """
    block_id = tle.program_id(0)
    chunk_start = block_id * BLOCK_SIZE
    lanes = chunk_start + tl.arange(0, BLOCK_SIZE)
    chunk = tl.load(
        inp + lanes,
        mask=lanes < M,
        other=get_dtype_min(inp.type.element_ty),
    )
    # Index returned by tl.max is relative to the chunk; shift it to a flat index.
    chunk_max, chunk_argmax = tl.max(chunk, axis=0, return_indices=True)
    tl.store(mid_value + block_id, chunk_max)
    tl.store(mid_index + block_id, chunk_argmax + chunk_start)

38 

39 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second pass of the full-tensor (dim=None) argmax.

    A single program loads the ``mid_size`` per-block maxima produced by the
    first pass. It finds the block holding the overall maximum and copies
    that block's recorded flat index ``mid_index[best]`` into ``out``.

    Padding lanes (``>= mid_size``) are loaded as ``get_dtype_min`` of the
    partial-max dtype.
    """
    lanes = tl.arange(0, BLOCK_MID)
    partial_max = tl.load(
        mid_value + lanes,
        mask=lanes < mid_size,
        other=get_dtype_min(mid_value.type.element_ty),
    )
    best_block = tl.argmax(partial_max, axis=0)
    tl.store(out, tl.load(mid_index + best_block))

52 

53 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_non_inner"))
@triton.jit
def argmax_kernel_non_inner(
    inp,
    out_index,
    M,
    N,
    K,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Argmax over the middle axis of a contiguous (M, N, K) input, K > 1.

    Element (m, n, k) is read at ``m * N * K + n * K + k``. The result for
    (m, k) is written to ``out_index[m * K + k]``.

    Program (pid_m, pid_k) handles row ``pid_m`` and the TILE_K columns
    starting at ``pid_k * TILE_K``. The wrapper in this file launches it on
    a grid of ``(M, cdiv(K, TILE_K))``. TILE_K, TILE_N and ONE_TILE_PER_CTA
    are filled in by the "argmax_non_inner" heuristic config; presumably
    ONE_TILE_PER_CTA means TILE_N covers all of N (verify against the
    heuristic).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)

    # Half-precision inputs use a float32 running maximum / fill value.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Fill value for masked-off loads; presumably the lowest value of cdtype.
    min_value = get_dtype_min(cdtype)

    if ONE_TILE_PER_CTA:
        # Single tile along N: reduce it directly, no running state needed.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset
        mask = k_offset < K and n_offset[:, None] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # Ties resolve to the lowest n, i.e. the first occurrence.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, local_argmax, mask=mask1)
    else:
        # Running per-column maximum and its index across N-tiles.
        max_values = tl.full([TILE_K], dtype=cdtype, value=min_value)
        argmax_values = tl.full([TILE_K], dtype=tl.int64, value=0)

        for start_n in range(0, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset
            mask = k_offset < K and n_offset[:, None] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict '>' keeps the earlier tile on ties (first occurrence wins).
            # NOTE(review): comparisons involving NaN are false here, so NaN
            # handling may differ from torch.argmax -- confirm if it matters.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            # local_argmax is tile-relative; shift by start_n to get n.
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, argmax_values, mask=mask1)

113 

114 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax over axis N of a contiguous (M, N, K) input.

    Element (m, n, k) is read at ``m * N * K + n * K + k``. The result for
    (m, k) is written to ``out_index[m * K + k]``.

    Each program handles BLOCK_M rows starting at ``pid_m * BLOCK_M`` and
    loops over N in chunks of BLOCK_N. The wrapper in this file launches it
    on a 1-D grid of ``cdiv(M, BLOCK_M)`` and only when K <= 1; larger K goes
    to argmax_kernel_non_inner.
    """
    # set offset
    pid_m = tle.program_id(0)
    # pid_k = tle.program_id(1)
    # (K is iterated in-kernel below instead of being a second grid axis.)
    for pid_k in range(K):
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

        dtype = inp.type.element_ty
        # bfloat16 keeps its running maximum in float32; other dtypes as-is.
        acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
        # Fill value for masked-off loads; presumably the lowest value of dtype.
        min_value = get_dtype_min(dtype)
        # Running per-row maximum and its index; reset for every pid_k.
        max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
        argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
        for start_n in range(0, N, BLOCK_N):
            n_offset = start_n + tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
            mask = m_offset[:, None] < M and n_offset[None, :] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            # Within a chunk, ties resolve to the lowest n (first occurrence).
            local_max, local_argmax = tl.max(
                inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
            )
            # if return indices is not supported, call a tl.argmax in addition
            # local_argmax = tl.argmax(inp_vals, 1)
            # Strict '>' keeps the earlier chunk on ties (first occurrence wins).
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            # local_argmax is chunk-relative; shift by start_n to get n.
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

        offset_index = m_offset * K + pid_k
        out_index_ptrs = out_index + offset_index
        mask1 = m_offset < M
        tl.store(out_index_ptrs, argmax_values, mask=mask1)

157 

158 

def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp`` (Ascend backend).

    Args:
        inp: Input tensor.
        dim: Dimension to reduce. ``None`` reduces over the whole tensor and
            returns an index into ``inp.flatten()``.
        keepdim: Keep the reduced dimension(s) with size 1.
        dtype: Only used when ``dim is None``. It sets the dtype of the
            intermediate per-block maxima buffer and defaults to ``inp.dtype``.

    Returns:
        An ``int64`` tensor of indices on ``inp.device``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS_ASCEND ARGMAX")
    if dim is None:
        # The flat kernels address ``inp`` as a dense 1-D buffer. A strided
        # view (e.g. a transpose) must therefore be materialized first;
        # otherwise the returned index would follow memory order rather
        # than the logical row-major order torch.argmax reports.
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # Two-pass reduction. Pass 1 splits the M elements into mid_size
        # blocks of block_size (~sqrt(M)) elements and records each block's
        # max and flat argmax. Pass 2 picks the winning block.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        # A full reduction yields a 0-d result, or an all-ones shape with keepdim.
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmax_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out

    if inp.ndim == 0:
        # A 0-d tensor holds a single element, so its argmax is always 0.
        # Like torch, accept dim in {-1, 0}; the generic range check below
        # would reject every dim for ndim == 0.
        assert dim in (-1, 0), "Invalid dim"
        return torch.zeros([], dtype=torch.int64, device=inp.device)

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)

    # Result shape with the reduced dimension kept as size 1.
    keep_shape = list(shape)
    keep_shape[dim] = 1

    if inp.numel() == 0:
        out_shape = keep_shape if keepdim else shape[:dim] + shape[dim + 1 :]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    # View the contiguous input as (M, N, K). N is the reduced dimension,
    # M the product of the leading dims, and K the product of the trailing dims.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    out_index = torch.empty(keep_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    with torch_device_fn.device(inp.device):
        if K > 1:
            # Reduced dim is not innermost: tile over (M, K), stride by K.
            grid = lambda meta: (
                M,
                triton.cdiv(K, meta["TILE_K"]),
            )
            argmax_kernel_non_inner[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )
        else:
            # Reduced dim is innermost (K == 1): tile over blocks of rows.
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            argmax_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )
    return out_index