Coverage for src/flag_gems/runtime/backend/_cambricon/ops/argmax.py: 0%

133 statements  

coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.shape_utils import can_use_int32_index

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def cfggen_reduce_op():
    return runtime.get_tuned_config("argmax_kernel_1")


@libentry()
@triton.jit
def argmax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    # Single-program path: one program loads all M elements and reduces them at once.
    offset = tl.arange(0, M)
    inp_val = tl.load(inp + offset)
    index_val = tl.argmax(inp_val, axis=0)
    tl.store(out, index_val.to(tl.int64))
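
For orientation, here is a host-side sketch of what argmax_kernel_once computes; this is illustrative PyTorch, not part of the kernel, and the tensor x is a made-up example:

# Illustrative host-side equivalent of argmax_kernel_once (assumption: any
# 1-D tensor small enough for a single program to reduce in one pass).
import torch

x = torch.randn(1024)
flat_index = torch.argmax(x)  # one reduction over all M elements, int64 result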


@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
)
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    real_size,
    M,
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(axis=0)

    # Split the M elements evenly across the launched programs.
    size_per_job = (M + num_jobs - 1) // num_jobs
    start_idx = pid * size_per_job
    end_idx = min(start_idx + size_per_job, M)

    # Keep a running (max value, global index) pair across BLOCK_SIZE tiles.
    max_tmp = -float("inf")
    index_tmp = 0
    if INT64_INDEX:
        index_tmp = index_tmp.to(tl.int64)
    for off in range(start_idx, end_idx, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < end_idx
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
        if max_val > max_tmp:
            max_tmp = max_val.to(tl.float32)
            index_tmp = max_index + off
    # Publish this program's partial result for the second-stage reduction.
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_tmp)
    tl.store(max_index_ptr, index_tmp)
    tl.store(real_size, num_jobs)


@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, real_size, mid_size: tl.constexpr):
    # Reduce the per-program partials written by argmax_kernel_1; only the
    # first `real_size` slots are valid.
    size = tl.load(real_size)
    offset = tl.arange(0, mid_size)
    mid_ptrs = mid_value + offset
    mid_val = tl.load(mid_ptrs, mask=offset < size, other=-float("inf"))
    index_val = tl.argmax(mid_val, axis=0)
    mid_index_ptrs = mid_index + index_val
    out_val = tl.load(mid_index_ptrs)
    tl.store(out, out_val)
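
argmax_kernel_1 and argmax_kernel_2 together form a two-stage reduction: stage one launches up to TOTAL_CORE_NUM programs, each of which scans its contiguous slice in BLOCK_SIZE tiles and writes one (max value, global index) partial; stage two reduces those partials to the final index. A minimal pure-Python sketch of the same combine logic (the function and variable names here are illustrative, not from the source):

# Two-stage argmax, host-side sketch (illustrative names throughout).
import math

def two_stage_argmax(x, num_jobs):
    n = len(x)
    chunk = math.ceil(n / num_jobs)
    partial_vals, partial_idx = [], []   # mirror mid_value / mid_index
    for pid in range(num_jobs):          # stage 1: one "program" per slice
        lo, hi = pid * chunk, min((pid + 1) * chunk, n)
        if lo >= hi:
            continue
        best = max(range(lo, hi), key=lambda i: x[i])
        partial_vals.append(x[best])
        partial_idx.append(best)
    # stage 2: argmax over the partials, then look up the stored global index
    winner = max(range(len(partial_vals)), key=lambda j: partial_vals[j])
    return partial_idx[winner]

assert two_stage_argmax([3, 1, 7, 2, 7, 0], num_jobs=2) == 2  # leftmost max wins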


@libentry()
@libtuner(
    configs=runtime.get_tuned_config("argmax"),
    key=[
        "M",
        "N",
    ],
    strategy=["log", "log"],
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,
):
    # Set offsets: each program owns BLOCK_M rows and one K slice of the
    # (M, N, K) view.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    if INT64_INDEX:
        pid_m = pid_m.to(tl.int64)
        pid_k = pid_k.to(tl.int64)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # If return_indices is not supported, call tl.argmax separately instead:
        # local_argmax = tl.argmax(inp_vals, 1)
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)
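
argmax_kernel reduces along an interior dimension by viewing the contiguous input as (M, N, K), with N the reduced axis; each program sweeps N in BLOCK_N steps for its BLOCK_M rows and one K slice. A hedged PyTorch sketch of the equivalent result (the shapes here are made-up examples):

# Illustrative reference for argmax_kernel's (M, N, K) view.
import torch

x = torch.randn(4, 7, 3).contiguous()         # stand-in input; dim=1 is reduced
M, N, K = 4, 7, 3
out_index = x.reshape(M, N, K).argmax(dim=1)  # shape (M, K), int64 indices into N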


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS_CAMBRICON ARGMAX")
    if dim is None:
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype

        use_int64_index = not can_use_int32_index(inp)

        if keepdim:
            shape = list(inp.shape)
            for i in range(0, inp.dim()):
                shape[i] = 1
            out = torch.empty(shape, dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        if M <= 65530:
            # Small input: a single program reduces everything in one pass.
            with torch_device_fn.device(inp.device):
                argmax_kernel_once[(1, 1, 1)](inp, out, M)
        else:
            # Large input: two-stage reduction across up to TOTAL_CORE_NUM programs.
            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            mid_size = TOTAL_CORE_NUM
            real_size = torch.empty([], dtype=torch.int32, device=inp.device)
            mid_value = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
            mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
            with torch_device_fn.device(inp.device):
                argmax_kernel_1[grid](
                    inp,
                    mid_value,
                    mid_index,
                    real_size,
                    M,
                    INT64_INDEX=use_int64_index,
                )
                argmax_kernel_2[(1, 1, 1)](
                    mid_value, mid_index, out, real_size, mid_size
                )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        if inp.numel() == 0:
            # Empty input: return int64 zeros of the reduced shape.
            out_shape = list(shape)
            if keepdim:
                out_shape[dim] = 1
            else:
                del out_shape[dim]
            return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
        # View the input as (M, N, K), where N is the reduced dimension.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()
        use_int64_index = not can_use_int32_index(inp)

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmax_kernel[grid](inp, out_index, M, N, K, INT64_INDEX=use_int64_index)

        return out_index
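
Finally, a short usage sketch of the wrapper. It assumes a Cambricon device is visible to PyTorch; the device string "mlu" is an assumption here, not something this file establishes:

# Usage sketch (assumption: an MLU device reachable as "mlu").
import torch

x = torch.randn(128, 64, device="mlu")
idx_flat = argmax(x)                        # dim=None: scalar int64 over all elements
idx_rows = argmax(x, dim=1, keepdim=True)   # per-row argmax, shape (128, 1)
assert torch.equal(idx_rows, torch.argmax(x, dim=1, keepdim=True))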