Coverage for src/flag_gems/ops/topk.py: 40%

183 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.language.core as core 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2, zeros_like 

13except ImportError: 

14 pass 

15 

16from flag_gems.runtime import torch_device_fn 

17from flag_gems.utils import libentry 

18from flag_gems.utils import triton_lang_extension as tle 

19from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

20 

logger = logging.getLogger(__name__)

# Dtype extrema wrapped in tl.constexpr so they can be folded into the Triton
# kernels below as compile-time constants (used as sentinel/padding values).
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

36 

37 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the maximum (return_max=True) or minimum representable value
    for a floating-point Triton dtype, as a compile-time constant.

    Used by the topk kernels to pad out-of-bounds lanes with a sentinel that
    can never win a max/min reduction.

    NOTE(review): implicitly returns nothing for dtypes other than
    float32/float16/bfloat16 -- callers pass only these three, or tl.float32
    directly. Confirm integer inputs are handled upstream before relying on
    this for non-float tensors.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL

58 

59 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the maximum (return_max=True) or minimum representable value
    for an integer Triton dtype, delegating to the shared limits helpers.
    """
    if return_max:
        return get_dtype_max(dtype)
    else:
        return get_dtype_min(dtype)

69 

70 

@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,  # out: per-chunk top-k values, laid out [batch, chunk, k]
    index_ptr,  # out: per-chunk top-k indices (global within the row)
    x_ptr,  # in: input tensor viewed as [batch, N]
    k,  # number of elements to select per chunk
    N: tl.constexpr,  # length of the reduced (last) dimension
    CHUNK_SIZE: tl.constexpr,  # elements handled by one program instance
    DESCENDING: tl.constexpr,  # True -> largest-k, False -> smallest-k
):
    """Stage 1 of two-stage topk: each program selects the local top-k of one
    CHUNK_SIZE slice of one batch row by repeated max/min extraction.

    Grid: (batch_size, chunk_num). Stage 2 then sorts the chunk_num * k
    candidates per row to produce the final result.
    """
    cur_batch = tle.program_id(0)
    cur_chunk_idx = tle.program_id(1)
    chunk_num = tle.num_programs(1)

    # Advance output pointers to this (batch, chunk)'s k-slot.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N

    # Pad out-of-bounds lanes with the value that can never be selected
    # (min for largest-k, max for smallest-k).
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    # NOTE(review): values are reduced in float32; integer/bfloat16 inputs are
    # therefore compared after conversion -- confirm acceptable for int64.
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    for k_idx in range(k):
        # Selection sort over the chunk: pick the current extreme, then
        # knock it out so the next iteration finds the runner-up.
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        tl.store(y_ptr + k_idx, chunk_select_val)
        # Store the index relative to the start of the row, not the chunk.
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )

120 

121 

122""" 

123Note(Zhengzekang): 

124Refer from triton2.2 official `sort` implementation: 

125https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404 

126Just add indices to sort with values. 

127""" 

128 

129 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic sort over `x`, carrying `ids`
    along so that sorted values keep their original indices.

    Adapted from Triton 2.2's standard `sort`; the index tensor is permuted
    with exactly the same XOR-swap decisions as the value tensor.

    Args:
        x: values being sorted (power-of-two length along the sort dim).
        ids: indices permuted in lockstep with x.
        flip: 0/1 (or a tensor of 0/1) selecting ascending vs descending
            order per sub-sequence at this round.
        i: round number; pairs compared at stride 2**(n_dims - i - 1).
        n_dims: log2 of the sorted-dimension length.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # Extract left/right partners of each pair via a 0/1 mask and a sum
    # over the middle (pair) axis, then broadcast back to the full shape.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # Compare-and-swap is done as a branchless XOR trick on the raw bits:
    # pick the matching-width integer type for the value bitcast.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # When the pair is out of order (relative to `flip`), XORing with
    # (ileft ^ iright) turns each element into its partner; otherwise XOR
    # with 0 leaves it unchanged.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Same width dispatch for the index tensor's bitcast.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    # Indices are swapped with the SAME `cond` as the values, keeping them
    # paired with the elements they originally labelled.
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)

191 

192 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Merge step of the bitonic sort, permuting `ids` alongside `x`.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

220 

221 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of `x` along `dim`, returning (sorted_x, permuted_ids).

    `ids` must carry the original positions; it is permuted in lockstep with
    the values so the caller can recover source indices (argsort semantics).
    The length along `dim` must be a power of two (required by _log2 and the
    bitonic network).
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Intermediate stages use order=2 (alternating) to build bitonic runs;
    # the final stage applies the requested overall order.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

230 

231 

@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,  # out: final top-k values, [batch, k]
    index_ptr,  # out: final top-k indices, [batch, k]
    chunk_x,  # in: stage-1 candidate values, [batch, N]
    chunk_index,  # in: stage-1 candidate indices, [batch, N]
    sort_dim: tl.constexpr,  # unused here beyond API symmetry; sort is over axis 0
    k: tl.constexpr,  # number of results to keep per batch row
    N: tl.constexpr,  # number of stage-1 candidates per row (chunk_num * k)
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N (bitonic sort needs pow2)
    DESCENDING: tl.constexpr,  # True -> largest-k
):
    """Stage 2 of two-stage topk: fully sort each row's stage-1 candidates
    with a bitonic argsort and write out the first k values/indices.

    Grid: (batch_size,).
    """
    cur_batch = tle.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Padding lanes get the value/index that sorts to the tail, so they can
    # never appear among the first k outputs.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    # NOTE(review): values are sorted in float32 and indices carried as
    # int32 -- assumes row length fits in int32 and float32 comparison is
    # faithful for the input dtype; confirm for int64/large-N inputs.
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Only the first k sorted lanes are written back.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)

267 

268 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the k largest (or smallest) elements of `x` along its last dim.

    Two-stage implementation: stage 1 extracts per-chunk local top-k
    candidates, stage 2 bitonic-sorts the candidates per row and keeps the
    first k.

    Args:
        x: input tensor; reduction is over the last dimension only.
        k: number of elements to select (k == 0 returns empty tensors).
        dim: dimension to reduce; must resolve to the last dimension.
        largest: True selects maxima, False selects minima.
        sorted: accepted for API compatibility; outputs are always sorted.

    Returns:
        (values, indices) tuple, each of shape x.shape[:-1] + (k,);
        indices have dtype int64.
    """
    logger.debug("GEMS TOPK")
    # Normalize negative dim; only the last dimension is supported.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    # assert sorted, "Currently only support sorted == True"

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    # largest maps directly onto the kernels' DESCENDING flag.
    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting
    # a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024

    # chunk_size must be >= k so every chunk can supply k candidates.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 scratch buffers: chunk_num * k candidates per batch row.
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    stage2_elem_cnt = chunk_num * k
    # Bitonic sort requires a power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    # Single device guard around both launches (both target x.device).
    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output values
            stage1_out_idx,  # pointer to the output indices
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)