Coverage for src/flag_gems/runtime/backend/_cambricon/ops/topk.py: 0%
174 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops.topk import topk_stage1_kernel, topk_stage2_kernel
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
12from ..utils import TOTAL_CORE_NUM
# Child logger under the package-level "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Compile-time (tl.constexpr) extreme representable values per dtype.
# These serve as sentinel fill values in the bubble top-k kernels below:
# already-selected (or out-of-bounds) elements are overwritten with the
# opposite extreme so they can never be selected again.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    # Return the extreme finite value for a floating-point triton dtype:
    # the max when `return_max` is truthy, otherwise the min.
    # The dtype comparison is resolved at triton compile time, so only one
    # branch survives compilation. Unsupported dtypes fall through with no
    # return value — callers are expected to pass f32/f16/bf16 only.
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    # Integer counterpart of _get_finfo_val: return the max (or min)
    # representable value for an integer triton dtype. The dtype test is a
    # compile-time branch; dtypes other than i16/i32/i64 fall through with
    # no return value.
    if dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
@triton.jit
def get_topk_bubble_res(
    buffer, buffer_ind, k, axis, mask_val, DESCENDING, BLOCK_M, BLOCK_N
):
    # Selection-style ("bubble") top-k over a [BLOCK_M, BLOCK_N] tile:
    # k times in a row, take the row-wise max (DESCENDING) or min, record
    # its value and its source index from `buffer_ind`, then overwrite the
    # selected slot with `mask_val` (the opposite extreme) so it cannot be
    # picked again. Returns ([BLOCK_M, k] values, [BLOCK_M, k] indices).
    # NOTE(review): the mask construction below compares against
    # tl.arange(0, BLOCK_N) on the second dimension, so this assumes
    # axis == 1 — confirm against callers (both call sites pass 1).
    kep_buffer_n = buffer
    topk_buffer_index_n = buffer_ind
    ret = tl.empty([BLOCK_M, k], dtype=buffer.dtype)
    ret_ind = tl.empty([BLOCK_M, k], dtype=buffer_ind.dtype)
    for k_ind in tl.range(0, k):
        if DESCENDING:
            sel_val, sel_index = tl.max(kep_buffer_n, axis=axis, return_indices=True)
        else:
            sel_val, sel_index = tl.min(kep_buffer_n, axis=axis, return_indices=True)
        if BLOCK_M > 1:
            # Multi-row tile: recover the global index by masking the index
            # buffer at the selected column and reducing (all other columns
            # are zeroed, so tl.max extracts the single surviving entry).
            mask_sel = tl.arange(0, BLOCK_N)[None, :] == sel_index[:, None]
            tep_sel_index_buffer = tl.where(mask_sel, topk_buffer_index_n, 0)
            sel_index_res = tl.max(tep_sel_index_buffer, axis=axis)
            sel_val_res = sel_val
            ret[:, k_ind] = sel_val_res
            ret_ind[:, k_ind] = sel_index_res
            # Update buffer: knock out the selected element with the sentinel.
            kep_buffer_n = tl.where(mask_sel, mask_val, kep_buffer_n)
        else:
            # Single-row tile: the reduction index can be used directly.
            indices = sel_index[0]
            ret[:, k_ind] = sel_val
            ret_ind[:, k_ind] = topk_buffer_index_n[:, indices]
            # Update buffer.
            kep_buffer_n[:, indices] = mask_val
    return ret, ret_ind
# Autotuning search space for the bubble kernel: candidate tile sizes along
# the batch dimension (BLOCK_BATCH -> TILE_M) and along the reduced last
# dimension (BLOCK_N -> TILE_N).
BLOCK_BATCH = [1, 16]
BLOCK_N = [128, 512, 1024, 2048]
def topk_cfggen():
    """Build the autotuner candidate list for ``topk_bubble_kernel``.

    Produces one ``triton.Config`` per (TILE_M, TILE_N, num_stages)
    combination drawn from ``BLOCK_BATCH`` x ``BLOCK_N`` x {1, 3},
    all with a single warp.
    """
    stage_options = [1, 3]
    configs = []
    for tile_m in BLOCK_BATCH:
        for tile_n in BLOCK_N:
            for stages in stage_options:
                configs.append(
                    triton.Config(
                        {"TILE_M": tile_m, "TILE_N": tile_n},
                        num_warps=1,
                        num_stages=stages,
                    )
                )
    return configs
def topk_config_prune(configs, named_args, **kwargs):
    """Prune and augment autotuner configs for ``topk_bubble_kernel``.

    Drops configs whose tile cannot hold ``k`` elements or whose TILE_M
    exceeds the per-program batch size, de-duplicates consecutive configs
    that both already cover the full row, attaches the derived TILE_M_NUM /
    TILE_N_NUM loop counts, and finally appends exact-fit TILE_N == N
    configs when N itself is not in the search space.
    """
    k = named_args["k"]
    n_elems = named_args["N"]
    rows_per_program = named_args["BLOCK_M"]
    kept = []

    for cfg in configs:
        tn = cfg.kwargs["TILE_N"]
        tm = cfg.kwargs["TILE_M"]
        # A tile must be able to hold k candidates and must not exceed the
        # number of rows a single program processes.
        if tn < k or tm > rows_per_program:
            continue
        # Skip a config whose TILE_N already exceeds N when the previously
        # kept config (same TILE_M) also covers the whole row.
        if kept:
            prev = kept[-1].kwargs
            if tn > n_elems and prev["TILE_N"] >= n_elems and prev["TILE_M"] == tm:
                continue
        cfg.kwargs["TILE_M_NUM"] = triton.cdiv(rows_per_program, tm)
        cfg.kwargs["TILE_N_NUM"] = triton.cdiv(n_elems, tn)
        kept.append(cfg)

    # Add exact-fit configs (one N-wide tile) when N is representable but
    # absent from the predefined BLOCK_N candidates.
    if n_elems not in BLOCK_N and n_elems <= max(BLOCK_N):
        for tm in BLOCK_BATCH:
            kept.append(
                triton.Config(
                    {
                        "TILE_M": tm,
                        "TILE_N": n_elems,
                        "TILE_M_NUM": triton.cdiv(rows_per_program, tm),
                        "TILE_N_NUM": 1,
                    },
                    num_warps=1,
                    num_stages=3,
                )
            )
    return kept
@libentry()
@libtuner(
    configs=topk_cfggen(),
    key=["k", "N", "M", "BLOCK_M"],
    prune_configs_by={"early_config_prune": topk_config_prune},
)
@triton.jit
def topk_bubble_kernel(
    inp_ptr,            # input, logically [M, N], row-major
    out_ptr,            # output values, logically [M, k]
    out_index_ptr,      # output indices, logically [M, k]
    k: tl.constexpr,    # number of elements to select per row
    M: tl.constexpr,    # number of rows (flattened batch)
    N: tl.constexpr,    # row length (top-k dimension)
    BLOCK_M: tl.constexpr,     # rows handled by one program
    TILE_M: tl.constexpr,      # rows processed per M-tile iteration
    TILE_N: tl.constexpr,      # columns processed per N-tile iteration
    TILE_M_NUM: tl.constexpr,  # = cdiv(BLOCK_M, TILE_M), set by the pruner
    TILE_N_NUM: tl.constexpr,  # = cdiv(N, TILE_N), set by the pruner
    DESCENDING: tl.constexpr,  # True -> largest-first, False -> smallest-first
):
    # Bubble-selection top-k: each program owns BLOCK_M consecutive rows.
    # For every TILE_M x TILE_N tile it computes a local top-k, gathers the
    # TILE_N_NUM partial results (TILE_N_NUM * k candidates per row), then
    # reduces those candidates to the final k values per row.
    pid = tl.program_id(0)
    m_st = pid * BLOCK_M

    # Sentinel: the dtype's opposite extreme, so masked/consumed slots lose
    # every max/min comparison.
    mask_val = _get_finfo_val(inp_ptr.dtype.element_ty, return_max=not DESCENDING)
    mask_val = mask_val.to(inp_ptr.dtype.element_ty)

    for m_block_ind in tl.range(0, TILE_M_NUM):
        m_iter_st = m_block_ind * TILE_M + m_st
        m_offset_val = m_iter_st + tl.arange(0, TILE_M)
        m_offset = m_offset_val[:, None]
        m_offset_mask = m_offset < M

        # Candidate buffers: the per-tile top-k results are concatenated
        # here (k slots per N-tile), pre-filled with the sentinel.
        topk_buffer_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=mask_val, dtype=inp_ptr.dtype.element_ty
        )
        topk_buffer_index_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=0, dtype=out_index_ptr.dtype.element_ty
        )
        for n_block_ind in tl.range(0, TILE_N_NUM):
            n_st = n_block_ind * TILE_N
            n_offset = n_st + tl.arange(0, TILE_N)[None, :]
            n_offset_mask = n_offset < N

            inp_mask = m_offset_mask & n_offset_mask
            inp_ptrs = inp_ptr + m_offset * N + n_offset
            # Out-of-bounds lanes load the sentinel so they never win.
            block_inp_val = tl.load(inp_ptrs, mask=inp_mask, other=mask_val)

            # Local top-k of this tile; n_offset doubles as the global
            # column index carried alongside each value.
            local_buffer, local_buffer_ind = get_topk_bubble_res(
                block_inp_val,
                n_offset.to(out_index_ptr.dtype.element_ty),
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N,
            )
            tep_index = n_block_ind * k
            topk_buffer_n[:, tep_index : tep_index + k] = local_buffer
            topk_buffer_index_n[:, tep_index : tep_index + k] = local_buffer_ind
        if TILE_N_NUM > 1:
            # Reduce the TILE_N_NUM * k candidates down to the final k.
            global_res, global_res_ind = get_topk_bubble_res(
                topk_buffer_n,
                topk_buffer_index_n,
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N_NUM * k,
            )
        else:
            # Single N-tile: the local result is already the answer.
            global_res = topk_buffer_n
            global_res_ind = topk_buffer_index_n

        # Store topk.
        store_ptrs = m_offset * k + tl.arange(0, k)[None, :]
        store_mask = m_offset_mask
        tl.store(store_ptrs + out_ptr, global_res, store_mask)
        tl.store(store_ptrs + out_index_ptr, global_res_ind, store_mask)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the top-k values and indices of `x` along its last dimension.

    Args:
        x: input tensor. Only the last dimension is supported.
        k: number of elements to select per row.
        dim: dimension to reduce; must resolve to ``x.ndim - 1``.
        largest: select the largest values when True, the smallest otherwise.
        sorted: must be True; results are produced in sorted order.

    Returns:
        ``(values, indices)`` tensors of shape ``x.shape[:-1] + (k,)``,
        with ``indices`` of dtype int64.
    """
    logger.debug("GEMS_CAMBRICON TOPK")
    # If dim equals to last dim, we set it to -1.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt
    out_shape = x.shape[:-1] + (k,)

    if k <= math.log2(topk_elem_cnt):
        # Small k relative to the row length: repeated max/min extraction
        # ("bubble" selection) beats a full sort.
        logger.debug("GEMS_CAMBRICON TOPK USING BUBBLE")
        topk_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        topk_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

        def grid_fn(meta):
            # One program per batch row, capped at the core count.
            return (min(batch_size, TOTAL_CORE_NUM),)

        block_m = triton.cdiv(batch_size, TOTAL_CORE_NUM)
        # Launch under the input's device context for consistency with the
        # sort path below (previously this launch ran outside the guard,
        # which could target the wrong device when x is not on the current
        # default device).
        with torch_device_fn.device(x.device):
            topk_bubble_kernel[grid_fn](
                x,
                topk_out,
                topk_out_idx,
                k,
                batch_size,
                topk_elem_cnt,
                block_m,
                DESCENDING=descending,
            )
        return (topk_out, topk_out_idx)
    else:
        logger.debug("GEMS_CAMBRICON TOPK USING SORT")
        # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
        if topk_elem_cnt < 1024:
            chunk_size = 256
        else:
            chunk_size = 1024

        # Note(Zhengzekang): We should promise chunk_size is larger than k.
        if chunk_size < k:
            chunk_size = triton.next_power_of_2(k)

        chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

        # Stage 1 emits per-chunk top-k candidates: chunk_num * k per row.
        stage1_out = torch.empty(
            batch_size * chunk_num * k, device=x.device, dtype=x.dtype
        )
        stage1_out_idx = torch.empty(
            batch_size * chunk_num * k, device=x.device, dtype=torch.int64
        )

        stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

        with torch_device_fn.device(x.device):
            topk_stage1_kernel[
                batch_size,
                chunk_num,
            ](
                stage1_out,  # pointer to the output
                stage1_out_idx,  # pointer to the output
                x,  # pointer to the input
                k,
                topk_elem_cnt,
                chunk_size,
                descending,
            )
        # Stage 2 reduces the chunk_num * k candidates per row to the final k.
        stage2_elem_cnt = chunk_num * k
        BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

        with torch_device_fn.device(x.device):
            topk_stage2_kernel[batch_size,](
                stage2_out,
                stage2_out_idx,
                stage1_out,
                stage1_out_idx,
                dim,
                k,
                stage2_elem_cnt,
                BLOCK_SIZE,
                descending,
            )

        return (stage2_out, stage2_out_idx)