Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/topk.py: 0%
187 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.language.core as core
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2, zeros_like
13except ImportError:
14 pass
16from flag_gems.runtime import torch_device_fn
17from flag_gems.utils import libentry
18from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Extremal representable values for every dtype the kernels below may pad
# with, captured once at import time and wrapped as triton compile-time
# constants (tl.constexpr) so they can be used inside @triton.jit functions.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the smallest (or, when ``return_max`` is true, the largest)
    finite value representable by the given floating-point triton dtype.

    Used as the padding value for masked loads so out-of-bounds lanes can
    never win a max (padded with min) or min (padded with max) reduction.

    The ``dtype is tl.<type>`` comparisons are resolved when the kernel is
    specialized, so only one branch survives compilation. Unlisted dtypes
    fall through with no return value.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Integer counterpart of ``_get_finfo_val``: return the minimum (or,
    when ``return_max`` is true, the maximum) value representable by the
    given integer triton dtype.

    Branching on ``dtype`` is resolved at kernel specialization time;
    unlisted dtypes fall through with no return value.
    """
    if dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of two-stage top-k: per-chunk candidate selection.

    Grid is (batch, chunk): each program loads one CHUNK_SIZE slice of one
    batch row and writes that chunk's k best values and their row-relative
    indices to ``y_ptr`` / ``index_ptr`` (laid out as
    [batch, chunk, k] flattened). Since the global top-k of a row must lie
    within the union of per-chunk top-k sets, stage 2 only has to sort
    chunk_num * k candidates per row.

    y_ptr      : output candidate values, length batch * chunk_num * k.
    index_ptr  : output candidate indices (position within the row).
    x_ptr      : input, contiguous rows of length N.
    k          : number of candidates to emit per chunk.
    N          : row length (elements per batch).
    CHUNK_SIZE : elements handled per program; power of two, >= k.
    DESCENDING : True selects maxima, False selects minima.
    """
    cur_batch = tle.program_id(0)
    cur_chunk_idx = tle.program_id(1)
    chunk_num = tle.num_programs(1)
    # Each (batch, chunk) program owns a private k-wide output slot.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset
    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N
    # Out-of-bounds lanes are padded with the value that can never win the
    # reduction (min for descending/max-select, max for ascending).
    # NOTE(review): if a chunk contains fewer than k in-bounds elements the
    # padding value itself is emitted as a candidate; stage 2's sort pushes
    # such candidates past position k as long as the row has >= k elements
    # in total — confirm callers guarantee k <= N.
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    # Iterative selection: k rounds of (arg)max/(arg)min, knocking out each
    # round's winner by overwriting it with the opposite extreme.
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)
        tl.store(y_ptr + k_idx, chunk_select_val)
        # Store the index relative to the start of the row, not the chunk.
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)
        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )
130"""
131Note(Zhengzekang):
132Refer from triton2.2 official `sort` implementation:
133https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
134Just add indices to sort with values.
135"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic sorting network, applied to
    the values ``x`` with the index tensor ``ids`` permuted in lockstep.

    ``i`` selects the comparison distance 2**(n_dims - i - 1); ``flip``
    (0/1, possibly per-element) chooses ascending vs. descending order for
    each compared pair. Returns the swapped ``(x, ids)`` pair.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    # View the flat data as [outer, 2, inner]: axis 1 separates the two
    # halves of every pair to be compared at this distance.
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # mask is 0 for the left half, 1 for the right half along axis 1, so the
    # masked sums below select each half and re-broadcast it to full shape.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # Same left/right extraction for the carried indices.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)
    # actual compare-and-swap: done in an integer bit-image of the data so a
    # conditional swap can be expressed branch-free via XOR.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # Swap condition, flipped per `flip`; x ^ (left ^ right) replaces each
    # element by its pair partner wherever the condition holds.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))
    # Apply the identical XOR-swap to the index tensor (its bitwidth may
    # differ from the values', hence a second dtype dispatch).
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")
    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))
    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Merge step of the bitonic sort: run ``stage`` compare-and-swap
    rounds over ``x`` (with ``ids`` permuted alongside).

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` along ``dim`` (must be the most minor dimension,
    of power-of-two length), permuting ``ids`` identically, and return the
    sorted ``(x, ids)`` pair. Intermediate stages sort alternating
    sub-sequences (order 2); the final stage applies ``descending``.
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,  # NOTE(review): accepted but unused inside the kernel
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of two-stage top-k: per-batch final selection.

    One program per batch row. Loads that row's N stage-1 candidates
    (values and indices), bitonic-sorts them, and stores the first k of
    the sorted order to ``y_ptr`` / ``index_ptr``.

    y_ptr / index_ptr : outputs, k elements per batch row.
    chunk_x / chunk_index : stage-1 candidates, N = chunk_num * k per row.
    BLOCK_SIZE : next power of two >= N (argsort needs a power-of-two size).
    DESCENDING : True returns the k largest, False the k smallest.
    """
    cur_batch = tle.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Pad value/index lanes beyond N so they sort to the tail, never into
    # the leading k positions.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL
    # Values are sorted in float32, indices carried as int32.
    # NOTE(review): int32 indices limit the addressable row length — confirm
    # acceptable for the supported input sizes.
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )
    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Keep only the leading k of the sorted candidates.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) elements of ``x`` along its
    last dimension, together with their indices.

    Two-stage algorithm:
      1. ``topk_stage1_kernel`` splits each row into chunks and selects the
         top-k of every chunk — the global top-k is guaranteed to lie within
         these ``chunk_num * k`` candidates;
      2. ``topk_stage2_kernel`` bitonic-sorts each row's candidates and
         keeps the first ``k``.

    Args:
        x: input tensor; reduction runs over the last dimension only.
        k: number of elements to select; must satisfy ``k <= x.shape[dim]``.
        dim: reduction dimension; only the last dimension is supported.
        largest: select maxima when True, minima when False.
        sorted: must be True; results are always returned in sorted order.

    Returns:
        ``(values, indices)`` of shape ``x.shape[:-1] + (k,)``; ``indices``
        are int64 positions within the last dimension.
    """
    logger.debug("GEMS TOPK")
    # If dim equals to last dim, we set it to -1.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    topk_elem_cnt = x.shape[dim]
    # Guard against out-of-range k: stage 1 would otherwise emit padding
    # values as candidates and silently return garbage.
    assert k <= topk_elem_cnt, "k must not exceed the size of the topk dimension"
    batch_size = math.prod(x.shape) // topk_elem_cnt

    descending = bool(largest)

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 candidate buffers: k winners per (batch, chunk).
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    stage2_elem_cnt = chunk_num * k
    # argsort requires a power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    # Launch both stages under a single device guard (the original entered
    # the same context twice).
    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)