Coverage for src/flag_gems/ops/topk.py: 40%
183 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.language.core as core
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2, zeros_like
13except ImportError:
14 pass
16from flag_gems.runtime import torch_device_fn
17from flag_gems.utils import libentry
18from flag_gems.utils import triton_lang_extension as tle
19from flag_gems.utils.limits import get_dtype_max, get_dtype_min
logger = logging.getLogger(__name__)

# Numeric limits for each supported dtype, wrapped in tl.constexpr so they can
# be folded as compile-time constants inside the Triton kernels below. The
# float limits serve as neutral padding values for masked loads in the top-k
# kernels; the int limits pad the index buffers in stage 2.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    # Return the floating-point max (return_max=True) or min (return_max=False)
    # for `dtype`. Used to pick a neutral "other" value for masked tl.load so
    # padded lanes never win a max/min reduction.
    # NOTE(review): only fp32/fp16/bf16 are handled; any other float dtype
    # falls through without an explicit return.
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    # Integer counterpart of _get_finfo_val: return the max (return_max=True)
    # or min (return_max=False) representable value of `dtype`, delegating to
    # the shared flag_gems limit helpers.
    if return_max:
        return get_dtype_max(dtype)
    else:
        return get_dtype_min(dtype)
@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of two-stage top-k: per-chunk candidate selection.

    Grid is (batch, chunk_num). Each program loads one CHUNK_SIZE-wide chunk
    of one batch row and iteratively extracts its k best elements (max if
    DESCENDING else min), writing k values to y_ptr and their row-relative
    source positions to index_ptr.
    """
    cur_batch = tle.program_id(0)
    cur_chunk_idx = tle.program_id(1)
    chunk_num = tle.num_programs(1)

    # Each (batch, chunk) pair owns a contiguous k-slot span in the outputs.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N

    # Pad out-of-range lanes with the value that can never be selected
    # (min when taking maxima, max when taking minima).
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    # Comparisons are done in fp32 regardless of the input dtype.
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        tl.store(y_ptr + k_idx, chunk_select_val)
        # Convert chunk-local position to a row-relative index.
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        # Knock out the element just selected so the next iteration picks
        # the runner-up.
        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )
122"""
123Note(Zhengzekang):
124Refer from triton2.2 official `sort` implementation:
125https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
126Just add indices to sort with values.
127"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of the bitonic sort, carrying `ids` along.

    Pairs elements of `x` that are 2**(n_dims - i - 1) apart and swaps each
    pair into ascending order, or descending where `flip` is set; whenever a
    value pair is swapped, the corresponding entries of `ids` are swapped
    identically.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    # View as [outer, 2, stride] so axis 1 separates the two halves of a pair.
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1): summing against a
    # 0/1 mask along axis 1 extracts each half, then broadcast back to shape.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same left/right extraction for the index tensor.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap: pick an integer type of matching width so the
    # swap can be done branchlessly with XOR on bitcast values.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # cond is True where the pair is out of order for the requested direction;
    # x ^ (left ^ right) replaces each element with its partner.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Mirror the same XOR-swap on the index tensor, using an integer type of
    # the indices' own bitwidth.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """
    One bitonic-merge stage over `x`, with `ids` permuted in lockstep.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of `x` along `dim`, returning (sorted values, permuted ids).

    `dim` must index the most minor dimension and the extent along it must be
    a power of two (required by `_log2` and the bitonic network).
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # All stages but the last sort alternating sub-sequences (order 2);
    # the final stage applies the caller's requested direction.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,  # NOTE(review): currently unused; sort is always along axis 0
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of two-stage top-k: merge all per-chunk candidates.

    One program per batch row: load the N stage-1 candidates (values and
    indices), sort them together with a bitonic argsort, and store the first
    k entries of the sorted result.
    """
    cur_batch = tle.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Pad values so padded lanes sort to the end; pad indices with the int32
    # extreme on the same side.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    # Sort in fp32/int32. NOTE(review): the int32 cast of indices assumes the
    # row length fits in int32 — confirm for very large inputs.
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Keep only the global top-k.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the k largest (or smallest) elements of `x` along its last dim.

    Two-stage implementation: stage 1 selects k candidates from each
    fixed-size chunk of a row; stage 2 sorts all candidates of a row and
    keeps the global top-k.

    Args:
        x: input tensor; only the last dimension is supported.
        k: number of elements to select (k == 0 returns empty tensors).
        dim: dimension to reduce over; must resolve to the last dimension.
        largest: select maxima when True, minima when False.
        sorted: kept for torch.topk API compatibility; this implementation
            always produces sorted output.

    Returns:
        (values, indices) with shape x.shape[:-1] + (k,); indices are int64.
    """
    logger.debug("GEMS TOPK")
    # Normalize a negative dim; only the last dimension is currently supported.
    if dim < 0:
        dim = dim + x.ndim
    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    # assert sorted, "Currently only support sorted == True"

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    # largest=True means a descending sort in both kernels.
    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a
    # proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024

    # chunk_size must be at least k so stage 1 can emit k candidates per chunk.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)
    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 scratch: k candidate values and source indices per chunk.
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    # Launch both stages under a single device guard (the previous code
    # entered the same device context twice).
    with torch_device_fn.device(x.device):
        # Stage 1: per-chunk top-k candidate selection.
        topk_stage1_kernel[batch_size, chunk_num](
            stage1_out,  # pointer to the output values
            stage1_out_idx,  # pointer to the output indices
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
        # Stage 2: sort all candidates of a row, keep the global top-k.
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)