Coverage for src/flag_gems/fused/grouped_topk.py: 6%
130 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def topk_with_k2_triton(
    scores_ptr,  # (num_tokens, num_experts) scores; experts laid out group-contiguously
    bias_ptr,  # per-expert bias, length num_experts; shared by all tokens (no token stride)
    group_scores_ptr,  # output: (num_tokens, n_group) per-group ranking score
    num_experts_per_group,
    n_group,
    stride_scores_token,  # row stride of scores_ptr
    stride_group_scores_token,  # row stride of group_scores_ptr
    BLOCK_SIZE: tl.constexpr,  # next power of 2 >= num_experts_per_group
    INPUT_DTYPE: tl.constexpr,  # triton dtype matching the scores tensor
):
    # One program per (token, group) pair: compute the sum of the two largest
    # (score + bias) values within the group and store it as the group's score.
    pid = tl.program_id(0)
    # Grid is laid out as num_tokens * n_group consecutive programs.
    token_id = pid // n_group
    group_id = pid % n_group
    lane = tl.arange(0, BLOCK_SIZE)
    mask = lane < num_experts_per_group  # inactive lanes beyond the group size
    # Experts of one group are contiguous: group g owns columns
    # [g * num_experts_per_group, (g + 1) * num_experts_per_group).
    scores_offset = token_id * stride_scores_token + group_id * num_experts_per_group
    bias_offset = group_id * num_experts_per_group
    # Padded lanes read -inf so they never win the max reductions below.
    x = tl.load(
        scores_ptr + scores_offset + lane,
        mask=mask,
        other=-float("inf"),
    )
    # Padded lanes read bias 0 (irrelevant: -inf + 0 is still -inf).
    b = tl.load(
        bias_ptr + bias_offset + lane,
        mask=mask,
        other=0.0,
    )
    # Rank by biased score.
    x = x + b
    # Reduce in float32 for accuracy regardless of the input dtype.
    x_f32 = x.to(tl.float32)
    max1 = tl.max(x_f32, axis=0)
    is_max1 = (x_f32 == max1) & mask
    count_max1 = tl.sum(is_max1.to(tl.int32), axis=0)
    # Second maximum: only knock out the top value when it is unique; if the
    # maximum is tied, the "second best" equals the maximum itself, so the
    # stored score is 2 * max1 in that case.
    x2 = tl.where(
        is_max1 & (count_max1 == 1),
        -float("inf"),
        x_f32,
    )
    max2 = tl.max(x2, axis=0)
    # Group ranking score = top-1 + top-2 of the biased scores, cast back to
    # the input dtype.
    group_scores_offset = token_id * stride_group_scores_token + group_id
    tl.store(
        group_scores_ptr + group_scores_offset,
        (max1 + max2).to(INPUT_DTYPE),
    )
@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,  # (num_tokens, num_experts) scores
    group_scores_ptr,  # (num_tokens, n_group) per-group scores from stage 1
    topk_values_ptr,  # output: (num_tokens, topk) float32 selected values
    topk_indices_ptr,  # output: (num_tokens, topk) int32 selected expert ids
    bias_ptr,  # per-expert bias, length num_experts (ranking only)
    num_tokens,
    n_group,
    topk_group,  # number of groups kept per token
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,  # row stride shared by both output tensors
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,  # next power of 2 >= n_group
    BLOCK_EXPERT: tl.constexpr,  # next power of 2 >= num_experts
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,  # 1 => rescale selected values to sum to routed_scaling_factor
):
    # One program per token: keep the best topk_group groups (tie-aware), then
    # pick the topk experts among those groups by (score + bias), storing the
    # raw score (without bias) as the output value.
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return
    neg_inf = -float("inf")
    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group
    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )
    group_scores_f32 = group_scores.to(tl.float32)
    # Treat NaN (x != x) and +inf as invalid; map them (and padded lanes) to -inf.
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)
    # If every group score is -inf there is nothing meaningful to select; the
    # outputs fall back to uniform defaults at the end.
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf
    # Tie-aware selection of the topk_group best groups. The padded lanes
    # (BLOCK_GROUP - n_group of them) are counted as already "taken", so the
    # loop stops once topk_group real groups' worth of values are accounted for.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf  # value of the topk_group-th best group after the loop
    for _ in range(TOPK_GROUP):
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)
        # Peel off every lane holding the current maximum (handles ties).
        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)
        newly = tl.sum(is_max.to(tl.int32), axis=0)
        # Remember the count before the final peel so we know how many lanes
        # at the threshold value may still be admitted.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        topk_group_value = tl.where(need, max_val, topk_group_value)
    # How many groups that exactly equal the threshold value are admitted.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value
    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value
    # Exclusive prefix count of threshold-valued lanes: ties at the threshold
    # are broken in favor of lower group indices.
    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i
    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group
    # Expand the group mask to a per-expert mask (experts are group-contiguous).
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    expert_group = expert_offsets // num_experts_per_group
    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert
    # Experts outside the selected groups read -inf and can never be picked.
    scored = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )
    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )
    # Ranking uses score + bias ...
    selection_scores_native = scored + expert_bias
    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )
    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)
    # Iterative argmax: pick one expert per round, preferring the lowest index
    # among tied maxima, then knock it out for the next round.
    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val
        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)
        # ... but the reported value is the raw score, reloaded without bias.
        selected_score = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)
        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )
    if renormalize == 1:
        # Epsilon guards against division by zero when all values are 0.
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor
    topk_vals = topk_vals * scale
    # Fallback when no group had a finite score: uniform weights 1/topk over
    # experts 0..topk-1.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)
    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)
    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )
    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )
def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Select ``topk`` experts per token using group-limited routing.

    Experts are partitioned into ``n_group`` contiguous groups. Each group is
    ranked by the sum of its two best (score + bias) entries, the best
    ``topk_group`` groups are kept, and the final ``topk`` experts are chosen
    from those groups only. The returned values are raw scores (the bias is
    used for ranking only), scaled by ``routed_scaling_factor`` and, when
    ``renormalize`` is set, first divided by their sum.

    Args:
        scores: 2D tensor of shape (num_tokens, num_experts).
        n_group: number of expert groups; must divide num_experts, at most 32.
        topk_group: number of groups kept per token.
        topk: number of experts selected per token, at most 32.
        renormalize: rescale the selected values before applying the factor.
        routed_scaling_factor: multiplier applied to the output values.
        bias: per-expert bias; coerced to scores' dtype and flattened.
        scoring_func: 0 = use scores as-is, 1 = apply sigmoid
            (computed as 0.5 * tanh(0.5 * x) + 0.5).

    Returns:
        Tuple of (topk_values, topk_indices): float32 and int32 tensors, both
        of shape (num_tokens, topk).

    Raises:
        ValueError: on invalid shapes, sizes, dtypes, or scoring_func.
    """
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")

    # Normalize the bias into a flat vector of matching dtype and length.
    if bias.dtype != scores.dtype:
        bias = bias.to(scores.dtype)
    if bias.ndim != 1:
        bias = bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )

    experts_per_group = num_experts // n_group

    # Map the torch dtype to the matching triton dtype for the kernels.
    tl_dtype = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
    }.get(scores.dtype)
    if tl_dtype is None:
        raise ValueError(f"Unsupported dtype: {scores.dtype}")

    if scoring_func == 1:
        from flag_gems.ops.tanh import tanh as gems_tanh

        # sigmoid(x) expressed via tanh: sigmoid(x) = 0.5 * tanh(0.5*x) + 0.5
        scores_processed = 0.5 * gems_tanh(0.5 * scores) + 0.5
    else:
        scores_processed = scores

    device = scores.device
    group_scores = torch.empty((num_tokens, n_group), device=device, dtype=scores.dtype)
    topk_values = torch.empty((num_tokens, topk), device=device, dtype=torch.float32)
    topk_indices = torch.empty((num_tokens, topk), device=device, dtype=torch.int32)

    # Stage 1: one program per (token, group) computes the group's top-2 sum.
    topk_with_k2_triton[(num_tokens * n_group,)](
        scores_processed,
        bias,
        group_scores,
        experts_per_group,
        n_group,
        scores_processed.stride(0),
        group_scores.stride(0),
        BLOCK_SIZE=triton.next_power_of_2(experts_per_group),
        INPUT_DTYPE=tl_dtype,
    )

    # Stage 2: one program per token keeps topk_group groups and picks topk
    # experts from them.
    group_idx_and_topk_triton[(num_tokens,)](
        scores_processed,
        group_scores,
        topk_values,
        topk_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        experts_per_group,
        routed_scaling_factor,
        scores_processed.stride(0),
        group_scores.stride(0),
        topk_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=triton.next_power_of_2(n_group),
        BLOCK_EXPERT=triton.next_power_of_2(num_experts),
        INPUT_DTYPE=tl_dtype,
        renormalize=int(renormalize),
    )

    return topk_values, topk_indices