Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/topk.py: 0%

187 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.language.core as core 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2, zeros_like 

13except ImportError: 

14 pass 

15 

16from flag_gems.runtime import torch_device_fn 

17from flag_gems.utils import libentry 

18from flag_gems.utils import triton_lang_extension as tle 

19 

# Module logger; strip any leading dot from __name__ so the child logger name
# is well-formed when this file is loaded under a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Compile-time (tl.constexpr) min/max finite values per dtype.  These are used
# as sentinel/padding values for masked loads and for "removing" an already
# selected element, so that padded lanes can never win a max/min reduction.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

33 

34 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    # Return the largest (return_max=True) or smallest finite value for a
    # floating-point Triton dtype, used as the sentinel for masked loads.
    # NOTE(review): there is no else branch — for any dtype other than
    # float32/float16/bfloat16 this falls through without returning; callers
    # are presumably restricted to these three dtypes. TODO confirm.
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL

55 

56 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    # Integer counterpart of _get_finfo_val: return the max (return_max=True)
    # or min representable value for an integer Triton dtype.
    # NOTE(review): same fall-through caveat — dtypes other than
    # int16/int32/int64 return nothing; callers must avoid them.
    if dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL

77 

78 

@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    # Stage 1 of the two-stage top-k: each program instance selects the local
    # top-k of one CHUNK_SIZE-wide slice of a row.  Grid is
    # (batch, num_chunks); per (batch, chunk) it writes k candidate values to
    # y_ptr and their row-relative indices to index_ptr, for stage 2 to merge.
    cur_batch = tle.program_id(0)
    cur_chunk_idx = tle.program_id(1)
    chunk_num = tle.num_programs(1)

    # Each (batch, chunk) pair owns a contiguous span of k output slots.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    # Advance the input pointer to this chunk's start within the row.
    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N

    # Pad out-of-range lanes with the opposite extreme (min when picking
    # maxima, max when picking minima) so padding can never be selected.
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    # Extract the extreme element k times; after each pick, the winning lane
    # is overwritten with the opposite sentinel so it cannot win again.
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        tl.store(y_ptr + k_idx, chunk_select_val)
        # Store the index relative to the whole row (chunk-local + offset).
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )

128 

129 

130""" 

131Note(Zhengzekang): 

132Refer from triton2.2 official `sort` implementation: 

133https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404 

134Just add indices to sort with values. 

135""" 

136 

137 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    # One compare-and-swap round of a bitonic sort, extended to swap the
    # companion index tensor `ids` in lockstep with the values `x`.
    # `i` selects the comparison distance 2**(n_dims - i - 1); `flip` encodes
    # the desired direction per element (0 ascending, 1 descending).
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # Split each comparison pair into its left/right halves by summing against
    # a 0/1 mask over the middle axis, then broadcast back to full shape so
    # every lane sees both of its pair's values.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same left/right extraction for the companion indices.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # Actual compare-and-swap: pick an integer type of the same bit width so
    # the swap can be done branchlessly with XOR on the raw bits.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # Where the pair is out of order (w.r.t. flip), XOR-ing with
    # (left ^ right) exchanges each lane's value for its partner's.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Repeat the bit-width dispatch for the index dtype.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    # Swap the indices with exactly the same condition as the values, keeping
    # value/index pairs aligned.
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)

199 

200 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """One bitonic-merge stage over (x, ids), swapping indices with values.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        # Ascending (0) or descending (1) applies uniformly to all elements.
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

228 

229 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    # Full bitonic sort of x along `dim`, permuting `ids` identically, so the
    # caller gets sorted values plus the original positions of each element.
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    # Requires the sorted extent to be a power of two (2**n_dims elements).
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Every stage but the last uses alternating order (2) to build bitonic
    # runs; the final stage applies the requested direction.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

238 

239 

@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,  # NOTE(review): declared but unused in this kernel
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    # Stage 2: one program per batch row sorts the N = chunk_num * k stage-1
    # candidates with a bitonic argsort and writes out the first k values and
    # their original row indices.
    cur_batch = tle.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Pad values/indices beyond N with extreme sentinels so padding sorts to
    # the tail and never appears among the first k outputs.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    # NOTE(review): indices are processed in int32 here although the outputs
    # are int64 — rows longer than int32 range would overflow; presumably out
    # of supported scope. TODO confirm.
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Only the first k sorted entries are the final top-k result.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)

275 

276 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the top-k values and indices of `x` along its last dimension.

    Two-stage implementation: stage 1 computes per-chunk top-k candidates in
    parallel, stage 2 sorts the candidates of each row and keeps the first k.

    Args:
        x: input tensor.
        k: number of elements to select.
        dim: dimension to reduce; must resolve to the last dimension.
        largest: select the largest (True) or smallest (False) elements.
        sorted: must be True; unsorted output is not supported.

    Returns:
        (values, indices) tensors of shape x.shape[:-1] + (k,); indices are
        int64 positions into the reduced dimension.
    """
    logger.debug("GEMS TOPK")
    # Normalize a negative dim; only the last dimension is supported.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    # largest=True -> sort descending so the head of the sort is the answer.
    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024
    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)
    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 scratch: chunk_num * k candidates (values + indices) per row.
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    # Final outputs: last dimension replaced by k.
    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    # Stage 2 sorts all candidates of a row in one power-of-two block.
    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        # Stage 1: per-(batch, chunk) local top-k candidates.
        topk_stage1_kernel[(batch_size, chunk_num)](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
        # Stage 2: per-batch merge of the candidates down to the final k.
        topk_stage2_kernel[(batch_size,)](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)