Coverage for src/flag_gems/runtime/backend/_cambricon/ops/topk.py: 0%

174 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops.topk import topk_stage1_kernel, topk_stage2_kernel 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11 

12from ..utils import TOTAL_CORE_NUM 

13 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot when __name__ comes from a relative import path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dtype extrema wrapped in tl.constexpr so Triton kernels can consume them as
# compile-time constants (used as sentinel/padding values for masked lanes).
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

27 

28 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the finite max (return_max=True) or min value for a float dtype.

    Supports tl.float32 / tl.float16 / tl.bfloat16. The branches resolve at
    compile time since `dtype` is a compile-time type. Note there is no final
    else: an unsupported dtype falls through and returns nothing, so callers
    must pass one of the three supported float dtypes.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL

49 

50 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the max (return_max=True) or min value for an integer dtype.

    Supports tl.int16 / tl.int32 / tl.int64. Mirrors _get_finfo_val: branches
    resolve at compile time, and an unsupported dtype falls through with no
    return value, so callers must pass one of the three supported int dtypes.
    """
    if dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL

71 

72 

@triton.jit
def get_topk_bubble_res(
    buffer, buffer_ind, k, axis, mask_val, DESCENDING, BLOCK_M, BLOCK_N
):
    """Select the top-k values (and their indices) from `buffer` along `axis`.

    Selection-sort style ("bubble"): k times, take the max (DESCENDING) or min
    of the remaining buffer, record it, then overwrite the selected slot with
    `mask_val` (the opposite extreme) so it cannot be picked again.

    buffer      : [BLOCK_M, BLOCK_N] candidate values.
    buffer_ind  : [BLOCK_M, BLOCK_N] original index of each candidate.
    mask_val    : sentinel written over consumed slots (min for DESCENDING,
                  max otherwise — supplied by the caller).
    axis        : reduction axis; callers pass 1 (reduce across BLOCK_N).
    Returns (ret, ret_ind), each [BLOCK_M, k], filled column-by-column in
    selection order (i.e. already sorted).
    """
    kep_buffer_n = buffer
    topk_buffer_index_n = buffer_ind
    ret = tl.empty([BLOCK_M, k], dtype=buffer.dtype)
    ret_ind = tl.empty([BLOCK_M, k], dtype=buffer_ind.dtype)
    for k_ind in tl.range(0, k):
        if DESCENDING:
            sel_val, sel_index = tl.max(kep_buffer_n, axis=axis, return_indices=True)
        else:
            sel_val, sel_index = tl.min(kep_buffer_n, axis=axis, return_indices=True)

        if BLOCK_M > 1:
            # Per-row one-hot mask marking the selected column of each row.
            mask_sel = tl.arange(0, BLOCK_N)[None, :] == sel_index[:, None]
            # Translate the local column back to the original index via the
            # index buffer (zero elsewhere, so a row-wise max extracts it).
            tep_sel_index_buffer = tl.where(mask_sel, topk_buffer_index_n, 0)
            sel_index_res = tl.max(tep_sel_index_buffer, axis=axis)
            sel_val_res = sel_val
            ret[:, k_ind] = sel_val_res
            ret_ind[:, k_ind] = sel_index_res

            # Update buffer.
            kep_buffer_n = tl.where(mask_sel, mask_val, kep_buffer_n)
        else:
            # Single row: index directly instead of building a one-hot mask.
            indices = sel_index[0]
            ret[:, k_ind] = sel_val
            ret_ind[:, k_ind] = topk_buffer_index_n[:, indices]
            # Update buffer.
            kep_buffer_n[:, indices] = mask_val
    return ret, ret_ind

104 

105 

# Autotuning search space: candidate row-tile (batch) and column-tile sizes.
# Shared between topk_cfggen (generation) and topk_config_prune (pruning).
BLOCK_BATCH = [1, 16]
BLOCK_N = [128, 512, 1024, 2048]

108 

109 

def topk_cfggen():
    """Enumerate autotuning configs for topk_bubble_kernel.

    Produces the cartesian product of row tiles (BLOCK_BATCH), column tiles
    (BLOCK_N) and pipeline stage counts, all with a single warp.
    """
    stage_options = [1, 3]
    return [
        triton.Config(
            {"TILE_M": tile_m, "TILE_N": tile_n},
            num_warps=1,
            num_stages=stages,
        )
        for tile_m in BLOCK_BATCH
        for tile_n in BLOCK_N
        for stages in stage_options
    ]

119 

120 

def topk_config_prune(configs, named_args, **kwargs):
    """Prune and augment autotune configs for topk_bubble_kernel.

    Drops tiles that cannot hold k columns or exceed the per-program row
    block; collapses consecutive duplicates whose column tile already covers
    the full row; derives TILE_M_NUM / TILE_N_NUM (iteration counts) for each
    survivor; and, when N itself is not a candidate tile but fits the search
    space, adds exact-fit single-pass configs.
    """
    k = named_args["k"]
    n_cols = named_args["N"]
    block_m = named_args["BLOCK_M"]
    kept = []

    for cfg in configs:
        tile_n = cfg.kwargs["TILE_N"]
        tile_m = cfg.kwargs["TILE_M"]
        # A tile must hold at least k columns and no more rows than the block.
        if tile_n < k or tile_m > block_m:
            continue
        if kept:
            prev = kept[-1].kwargs
            # Both this tile and the previous one already span the whole row
            # with the same row tile — keeping both is redundant.
            if tile_n > n_cols and prev["TILE_N"] >= n_cols and prev["TILE_M"] == tile_m:
                continue
        # Derived loop trip counts consumed by the kernel.
        cfg.kwargs["TILE_M_NUM"] = triton.cdiv(block_m, tile_m)
        cfg.kwargs["TILE_N_NUM"] = triton.cdiv(n_cols, tile_n)
        kept.append(cfg)

    if n_cols not in BLOCK_N and n_cols <= max(BLOCK_N):
        # Exact-fit column tile: one pass over the row (TILE_N_NUM == 1).
        for tile_m in BLOCK_BATCH:
            kept.append(
                triton.Config(
                    {
                        "TILE_M": tile_m,
                        "TILE_N": n_cols,
                        "TILE_M_NUM": triton.cdiv(block_m, tile_m),
                        "TILE_N_NUM": 1,
                    },
                    num_warps=1,
                    num_stages=3,
                )
            )
    return kept

156 

157 

@libentry()
@libtuner(
    configs=topk_cfggen(),
    key=["k", "N", "M", "BLOCK_M"],
    prune_configs_by={"early_config_prune": topk_config_prune},
)
@triton.jit
def topk_bubble_kernel(
    inp_ptr,  # [M, N] input, row-major
    out_ptr,  # [M, k] output values
    out_index_ptr,  # [M, k] output indices (integer dtype)
    k: tl.constexpr,
    M: tl.constexpr,  # number of rows (batch)
    N: tl.constexpr,  # row length
    BLOCK_M: tl.constexpr,  # rows handled by one program
    TILE_M: tl.constexpr,  # rows per inner tile
    TILE_N: tl.constexpr,  # columns per inner tile
    TILE_M_NUM: tl.constexpr,  # = cdiv(BLOCK_M, TILE_M), set by the pruner
    TILE_N_NUM: tl.constexpr,  # = cdiv(N, TILE_N), set by the pruner
    DESCENDING: tl.constexpr,  # True => largest-k, False => smallest-k
):
    """Selection-based top-k over the last dim for small k.

    Each program owns BLOCK_M rows. Per TILE_M row-slab it gathers the local
    top-k of every TILE_N column tile into a [TILE_M, TILE_N_NUM * k]
    candidate buffer, then reduces that buffer to the final top-k.
    """
    pid = tl.program_id(0)
    m_st = pid * BLOCK_M

    # Sentinel for masked/consumed lanes: the opposite extreme of the search
    # direction, so it can never be selected.
    mask_val = _get_finfo_val(inp_ptr.dtype.element_ty, return_max=not DESCENDING)
    mask_val = mask_val.to(inp_ptr.dtype.element_ty)

    for m_block_ind in tl.range(0, TILE_M_NUM):
        m_iter_st = m_block_ind * TILE_M + m_st
        m_offset_val = m_iter_st + tl.arange(0, TILE_M)
        m_offset = m_offset_val[:, None]
        m_offset_mask = m_offset < M

        # Candidate buffers pre-filled with the sentinel / index 0 so that
        # partially-filled tiles stay inert during the final reduction.
        topk_buffer_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=mask_val, dtype=inp_ptr.dtype.element_ty
        )
        topk_buffer_index_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=0, dtype=out_index_ptr.dtype.element_ty
        )
        for n_block_ind in tl.range(0, TILE_N_NUM):
            n_st = n_block_ind * TILE_N
            n_offset = n_st + tl.arange(0, TILE_N)[None, :]
            n_offset_mask = n_offset < N

            inp_mask = m_offset_mask & n_offset_mask
            inp_ptrs = inp_ptr + m_offset * N + n_offset
            # Out-of-range lanes load the sentinel and are never selected.
            block_inp_val = tl.load(inp_ptrs, mask=inp_mask, other=mask_val)

            # Local top-k of this column tile; n_offset doubles as the
            # original-column index buffer.
            local_buffer, local_buffer_ind = get_topk_bubble_res(
                block_inp_val,
                n_offset.to(out_index_ptr.dtype.element_ty),
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N,
            )
            tep_index = n_block_ind * k
            topk_buffer_n[:, tep_index : tep_index + k] = local_buffer
            topk_buffer_index_n[:, tep_index : tep_index + k] = local_buffer_ind
        if TILE_N_NUM > 1:
            # Reduce the TILE_N_NUM partial top-k lists to the global top-k.
            global_res, global_res_ind = get_topk_bubble_res(
                topk_buffer_n,
                topk_buffer_index_n,
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N_NUM * k,
            )
        else:
            # Single tile covered the whole row: its local result is final.
            global_res = topk_buffer_n
            global_res_ind = topk_buffer_index_n

        # Store topk.
        store_ptrs = m_offset * k + tl.arange(0, k)[None, :]
        store_mask = m_offset_mask
        tl.store(store_ptrs + out_ptr, global_res, store_mask)
        tl.store(store_ptrs + out_index_ptr, global_res_ind, store_mask)

239 

240 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the top-k values and indices of `x` along its last dimension.

    Args:
        x: input tensor.
        k: number of elements to select per row.
        dim: reduction dimension; must resolve to the last dimension.
        largest: pick the k largest values when True, the k smallest otherwise.
        sorted: must be True; results come back in sorted order.

    Returns:
        Tuple (values, indices), each shaped x.shape[:-1] + (k,); indices
        are int64.
    """
    logger.debug("GEMS_CAMBRICON TOPK")
    # Normalize a negative dim before validating it.
    if dim < 0:
        dim += x.ndim
    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt
    out_shape = x.shape[:-1] + (k,)

    if k <= math.log2(topk_elem_cnt):
        # k is small relative to the row length: the selection ("bubble")
        # kernel wins over sorting.
        logger.debug("GEMS_CAMBRICON TOPK USING BUBBLE")
        topk_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        topk_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

        def grid_fn(meta):
            # One program per batch row, capped at the available core count.
            return (min(batch_size, TOTAL_CORE_NUM),)

        block_m = triton.cdiv(batch_size, TOTAL_CORE_NUM)
        topk_bubble_kernel[grid_fn](
            x,
            topk_out,
            topk_out_idx,
            k,
            batch_size,
            topk_elem_cnt,
            block_m,
            DESCENDING=descending,
        )
        return (topk_out, topk_out_idx)

    # Large k: two-stage chunked sort.
    logger.debug("GEMS_CAMBRICON TOPK USING SORT")
    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024
    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 scratch: per-chunk top-k candidates for every batch row.
    stage1_out = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=x.dtype
    )
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    with torch_device_fn.device(x.device):
        # Stage 1: each (row, chunk) program emits its local top-k.
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        # Stage 2: reduce all chunk candidates to the final top-k per row.
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)