Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/grouped_topk.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import torch
2import triton
3import triton.language as tl
4from triton.language.extra.cuda import libdevice
@triton.jit
def topk_with_k2_triton(
    scores_ptr,
    bias_ptr,
    group_scores_ptr,
    num_experts_per_group,
    n_group,
    stride_scores_token,
    stride_group_scores_token,
    scoring_func: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
):
    # Stage-1 kernel: one program per (token, group) pair.
    # Reduces the group's expert scores to a single group score,
    # defined as top1 + top2 of the (optionally activated, bias-added)
    # scores, and stores it at group_scores[token, group].
    pid = tl.program_id(0)

    token_id = pid // n_group
    group_id = pid % n_group

    # BLOCK_SIZE is the next power of two >= num_experts_per_group
    # (see host launcher); out-of-range lanes are masked off.
    lane = tl.arange(0, BLOCK_SIZE)
    mask = lane < num_experts_per_group

    # Experts of one group are contiguous within a token's row.
    scores_offset = token_id * stride_scores_token + group_id * num_experts_per_group
    bias_offset = group_id * num_experts_per_group

    x = tl.load(
        scores_ptr + scores_offset + lane,
        mask=mask,
        other=-float("inf"),  # padding never wins a max
    )

    b = tl.load(
        bias_ptr + bias_offset + lane,
        mask=mask,
        other=0.0,
    )

    # scoring_func == 1 applies a sigmoid via the identity
    # sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5, computed in fp32.
    if scoring_func == 1:
        x_f32 = x.to(tl.float32)
        x_f32 = 0.5 * libdevice.tanh(0.5 * x_f32) + 0.5
        x = x_f32.to(INPUT_DTYPE)

    # Bias participates in the ranking (added in input dtype).
    x = x + b

    # Reductions run in fp32 regardless of input dtype.
    x_f32 = x.to(tl.float32)

    max1 = tl.max(x_f32, axis=0)
    is_max1 = (x_f32 == max1) & mask
    count_max1 = tl.sum(is_max1.to(tl.int32), axis=0)

    # Knock out the maximum only if it is unique; if max1 appears more
    # than once, the runner-up equals max1 and nothing is masked.
    x2 = tl.where(
        is_max1 & (count_max1 == 1),
        -float("inf"),
        x_f32,
    )
    max2 = tl.max(x2, axis=0)

    group_scores_offset = token_id * stride_group_scores_token + group_id
    tl.store(
        group_scores_ptr + group_scores_offset,
        (max1 + max2).to(INPUT_DTYPE),
    )
@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,
    group_scores_ptr,
    topk_values_ptr,
    topk_indices_ptr,
    bias_ptr,
    num_tokens,
    n_group,
    topk_group,
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    scoring_func: tl.constexpr,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,
):
    # Stage-2 kernel: one program per token.
    # 1) pick the best `topk_group` groups by their stage-1 group scores,
    # 2) restrict experts to those groups,
    # 3) iteratively take the top `topk` experts by biased score,
    # 4) store (optionally renormalized) scaled values and int32 indices.
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    neg_inf = -float("inf")

    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group

    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )

    # Sanitize: NaN (x != x) and +inf group scores are demoted to -inf
    # so they never get selected.
    group_scores_f32 = group_scores.to(tl.float32)
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)

    # If every group score collapsed to -inf there is nothing meaningful
    # to route; a uniform fallback is emitted at the bottom instead.
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf

    # Iterative selection of the topk_group-th group score.
    # The BLOCK_GROUP - n_group padding lanes are pre-counted as "already
    # selected" so the loop effectively stops after topk_group real picks.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf

    for _ in range(TOPK_GROUP):
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)

        # Remove the current maximum (all lanes tied at it) from `value`.
        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)

        newly = tl.sum(is_max.to(tl.int32), axis=0)

        # Remember the count *before* this round; used below to know how
        # many groups tied at the threshold may still be admitted.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        # topk_group_value ends up as the threshold (last admitted) score.
        topk_group_value = tl.where(need, max_val, topk_group_value)

    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value

    # Groups strictly above the threshold are always in; groups equal to
    # it are admitted lowest-index-first until the quota is filled
    # (prefix_eq is an exclusive prefix count of earlier ties).
    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value

    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i

    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group

    # Expand the group mask to a per-expert mask via group membership.
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    expert_group = expert_offsets // num_experts_per_group

    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert

    raw_scores = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )

    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )

    # scoring_func == 1: sigmoid via 0.5*tanh(0.5x)+0.5, same as stage 1.
    if scoring_func == 1:
        scored_f32 = raw_scores.to(tl.float32)
        scored_f32 = 0.5 * libdevice.tanh(0.5 * scored_f32) + 0.5
        scored = scored_f32.to(INPUT_DTYPE)
    else:
        scored = raw_scores

    # Bias influences *which* experts are selected, but the stored output
    # values are the bias-free (post-activation) scores reloaded below.
    selection_scores_native = scored + expert_bias

    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )

    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)

    # Iterative argmax: each round takes the best remaining expert,
    # breaking score ties by the smallest expert index.
    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val

        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)

        # Re-read the raw (unbiased) score of the winner for the output.
        selected_raw = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)

        if scoring_func == 1:
            selected_score = 0.5 * libdevice.tanh(0.5 * selected_raw) + 0.5
        else:
            selected_score = selected_raw

        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)

        # Retire the winner so the next round picks someone else.
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )

    # Optional renormalization: values sum to routed_scaling_factor
    # (epsilon guards an all-zero sum).
    if renormalize == 1:
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor

    topk_vals = topk_vals * scale

    # Fallback for the degenerate all--inf case: uniform weights 1/topk
    # and identity indices 0..topk-1.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)

    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)

    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )

    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )
def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Group-limited top-k expert routing (two-stage Triton launch).

    Stage 1 scores every (token, group) pair as the sum of its two best
    biased expert scores; stage 2 keeps the best ``topk_group`` groups per
    token and selects the final ``topk`` experts from those groups only.

    Args:
        scores: 2D tensor of shape (num_tokens, num_experts).
        n_group: number of expert groups (must divide num_experts, <= 32).
        topk_group: number of groups each token may route through.
        topk: number of experts selected per token (<= 32).
        renormalize: if True, output weights are rescaled to sum to
            ``routed_scaling_factor``.
        routed_scaling_factor: multiplier applied to the output weights.
        bias: per-expert selection bias; cast/flattened to match scores.
        scoring_func: 0 = raw scores, 1 = sigmoid applied before biasing.

    Returns:
        (topk_values, topk_indices): float32 / int32 tensors, each of
        shape (num_tokens, topk).

    Raises:
        ValueError: on malformed shapes, unsupported dtype, or arguments
            outside the supported ranges.
    """
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")

    # Normalize bias into a flat vector with the same dtype as scores.
    bias = bias if bias.dtype == scores.dtype else bias.to(scores.dtype)
    bias = bias if bias.ndim == 1 else bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )

    experts_per_group = num_experts // n_group

    # Map the torch dtype onto its Triton counterpart for the kernels.
    dtype_map = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
    }
    tl_dtype = dtype_map.get(scores.dtype)
    if tl_dtype is None:
        raise ValueError(f"Unsupported dtype: {scores.dtype}")

    grp_scores = torch.empty(
        (num_tokens, n_group),
        device=scores.device,
        dtype=scores.dtype,
    )
    out_values = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.float32,
    )
    out_indices = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.int32,
    )

    # Stage 1: one program per (token, group); lanes cover one group.
    k2_block = triton.next_power_of_2(experts_per_group)
    topk_with_k2_triton[(num_tokens * n_group,)](
        scores,
        bias,
        grp_scores,
        experts_per_group,
        n_group,
        scores.stride(0),
        grp_scores.stride(0),
        scoring_func,
        BLOCK_SIZE=k2_block,
        INPUT_DTYPE=tl_dtype,
    )

    # Stage 2: one program per token; blocks padded to powers of two.
    grp_block = triton.next_power_of_2(n_group)
    expert_block = triton.next_power_of_2(num_experts)
    group_idx_and_topk_triton[(num_tokens,)](
        scores,
        grp_scores,
        out_values,
        out_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        experts_per_group,
        routed_scaling_factor,
        scoring_func,
        scores.stride(0),
        grp_scores.stride(0),
        out_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=grp_block,
        BLOCK_EXPERT=expert_block,
        INPUT_DTYPE=tl_dtype,
        renormalize=int(renormalize),
    )

    return out_values, out_indices