Coverage for src/flag_gems/fused/grouped_topk.py: 8%
133 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def topk_with_k2_triton(
    scores_ptr,
    bias_ptr,
    group_scores_ptr,
    num_experts_per_group,
    n_group,
    stride_scores_token,
    stride_group_scores_token,
    BLOCK_SIZE: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
):
    # Per-(token, group) kernel: writes
    #   group_scores[token, group] = top1 + top2 of (score + bias)
    # over the experts belonging to that group.  The grid is 1-D with
    # num_tokens * n_group program instances.
    pid = tl.program_id(0)
    # Decode the flat program id into (token, group) coordinates.
    token_id = pid // n_group
    group_id = pid % n_group
    # One lane per expert slot; BLOCK_SIZE is a power of two
    # >= num_experts_per_group, so mask off the padding lanes.
    lane = tl.arange(0, BLOCK_SIZE)
    mask = lane < num_experts_per_group
    # Assumes each group's experts are contiguous: group g owns expert ids
    # [g * num_experts_per_group, (g + 1) * num_experts_per_group).
    scores_offset = token_id * stride_scores_token + group_id * num_experts_per_group
    bias_offset = group_id * num_experts_per_group
    # Padding lanes read -inf so they can never win a max reduction.
    x = tl.load(
        scores_ptr + scores_offset + lane,
        mask=mask,
        other=-float("inf"),
    )
    # Padding lanes read 0 bias (value irrelevant: score is already -inf).
    b = tl.load(
        bias_ptr + bias_offset + lane,
        mask=mask,
        other=0.0,
    )
    # Rank experts by biased score.
    x = x + b
    # Do the reductions in fp32 regardless of the input dtype.
    x_f32 = x.to(tl.float32)
    max1 = tl.max(x_f32, axis=0)
    is_max1 = (x_f32 == max1) & mask
    count_max1 = tl.sum(is_max1.to(tl.int32), axis=0)
    # Second-largest value *counting duplicates*: only knock out the maximum
    # when it is unique; if max1 occurs more than once, max2 == max1.
    x2 = tl.where(
        is_max1 & (count_max1 == 1),
        -float("inf"),
        x_f32,
    )
    max2 = tl.max(x2, axis=0)
    # Store the group score, cast back to the input dtype.
    group_scores_offset = token_id * stride_group_scores_token + group_id
    tl.store(
        group_scores_ptr + group_scores_offset,
        (max1 + max2).to(INPUT_DTYPE),
    )
@triton.jit
def group_idx_and_topk_triton(
    scores_ptr,
    group_scores_ptr,
    topk_values_ptr,
    topk_indices_ptr,
    bias_ptr,
    num_tokens,
    n_group,
    topk_group,
    topk,
    num_experts,
    num_experts_per_group,
    routed_scaling_factor,
    stride_scores_token,
    stride_group_scores_token,
    stride_out_token,
    # NOTE(review): N_GROUP appears unused in this kernel body — confirm it
    # isn't needed for specialization before removing from call sites.
    N_GROUP: tl.constexpr,
    TOPK_GROUP: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_GROUP: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    INPUT_DTYPE: tl.constexpr,
    renormalize: tl.constexpr,
):
    # One program per token: picks the top `topk_group` groups by group score,
    # then the top `topk` experts (by score + bias) restricted to those
    # groups, and writes the *unbiased* scores (scaled) plus expert indices.
    pid = tl.program_id(0)
    # Guard against an over-sized launch grid.
    if pid >= num_tokens:
        return
    neg_inf = -float("inf")
    # One lane per group; BLOCK_GROUP is a power of two >= n_group.
    group_offsets = tl.arange(0, BLOCK_GROUP)
    valid_group = group_offsets < n_group
    group_scores = tl.load(
        group_scores_ptr + pid * stride_group_scores_token + group_offsets,
        mask=valid_group,
        other=neg_inf,
    )
    # Sanitize: NaN (x != x) and +inf are forced to -inf so they never rank.
    group_scores_f32 = group_scores.to(tl.float32)
    is_finite = (group_scores_f32 == group_scores_f32) & (
        group_scores_f32 != float("inf")
    )
    group_scores_f32 = tl.where(is_finite & valid_group, group_scores_f32, neg_inf)
    # If every group score is -inf/NaN/+inf, fall through but emit the uniform
    # fallback outputs at the end (see if_proceed below).
    max_group_score = tl.max(group_scores_f32, axis=0)
    if_proceed = max_group_score != neg_inf
    # --- Find the value of the topk_group-th largest group score, with tie
    # accounting.  The BLOCK_GROUP - n_group padding lanes hold -inf and are
    # treated as pre-consumed, hence the offsets in the two counters below.
    value = group_scores_f32
    target_num_min = BLOCK_GROUP - n_group + topk_group
    count_equal_to_top_value = BLOCK_GROUP - n_group
    pre_count_equal_to_top_value = 0
    topk_group_value = neg_inf
    for _ in range(TOPK_GROUP):
        # Stop updating once we have consumed at least topk_group real lanes.
        need = count_equal_to_top_value < target_num_min
        max_val = tl.max(value, axis=0)
        # Peel off *all* lanes tied at the current max in one iteration.
        is_max = need & (value == max_val)
        value = tl.where(is_max, neg_inf, value)
        newly = tl.sum(is_max.to(tl.int32), axis=0)
        # Count consumed before this peel — used to size the tie budget.
        pre_count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value, pre_count_equal_to_top_value
        )
        count_equal_to_top_value = tl.where(
            need, count_equal_to_top_value + newly, count_equal_to_top_value
        )
        # Last value written while `need` holds is the topk_group-th score.
        topk_group_value = tl.where(need, max_val, topk_group_value)
    # How many groups tied at the threshold value still fit in the budget.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value
    group_gt = group_scores_f32 > topk_group_value
    group_eq = group_scores_f32 == topk_group_value
    # Deterministic tie-break: keep only the first (lowest-index) ties, via an
    # exclusive prefix sum over the equality mask.
    eq_i = group_eq.to(tl.int32)
    prefix_eq = tl.cumsum(eq_i, axis=0) - eq_i
    group_selected = (
        group_gt | (group_eq & (prefix_eq < num_equalto_topkth_group))
    ) & valid_group
    # --- Expand the group mask to an expert mask.
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    valid_expert = expert_offsets < num_experts
    # Assumes contiguous group layout: expert e belongs to group e // size.
    expert_group = expert_offsets // num_experts_per_group
    # (BLOCK_EXPERT, BLOCK_GROUP) membership matrix reduced over groups:
    # an expert is selected iff its group is selected.
    expert_in_group = expert_group[:, None] == group_offsets[None, :]
    expert_selected = (
        tl.sum((expert_in_group & group_selected[None, :]).to(tl.int32), axis=1) > 0
    ) & valid_expert
    # Unselected experts read -inf so they cannot enter the top-k.
    scored = tl.load(
        scores_ptr + pid * stride_scores_token + expert_offsets,
        mask=expert_selected,
        other=neg_inf,
    )
    expert_bias = tl.load(
        bias_ptr + expert_offsets,
        mask=valid_expert,
        other=0.0,
    )
    # Bias is used for *ranking* only; the stored values are unbiased scores.
    selection_scores_native = scored + expert_bias
    selection_scores = tl.where(
        expert_selected,
        selection_scores_native.to(tl.float32),
        neg_inf,
    )
    topk_vals = tl.full([TOPK], 0.0, tl.float32)
    topk_idx = tl.full([TOPK], 0, tl.int32)
    pos_range = tl.arange(0, TOPK)
    # --- Iteratively extract the top TOPK experts by biased score.
    for i in range(TOPK):
        max_val = tl.max(selection_scores, axis=0)
        is_max = selection_scores == max_val
        # Tie-break by smallest expert index; num_experts + 1 is a sentinel
        # larger than any real index.
        candidate_idx = tl.where(is_max, expert_offsets, num_experts + 1)
        selected_idx = tl.min(candidate_idx, axis=0)
        # Re-load the winner's raw (unbiased) score as the output value.
        selected_score = tl.load(
            scores_ptr + pid * stride_scores_token + selected_idx,
            mask=selected_idx < num_experts,
            other=neg_inf,
        ).to(tl.float32)
        # Scatter into slot i of the accumulators.
        topk_vals = tl.where(pos_range == i, selected_score, topk_vals)
        topk_idx = tl.where(pos_range == i, selected_idx.to(tl.int32), topk_idx)
        # Knock the winner out for the next iteration.
        selection_scores = tl.where(
            expert_offsets == selected_idx, neg_inf, selection_scores
        )
    # Optionally renormalize the weights to sum to routed_scaling_factor;
    # the epsilon guards against division by zero.
    if renormalize == 1:
        topk_sum = tl.sum(topk_vals, axis=0) + 1e-20
        scale = routed_scaling_factor / topk_sum
    else:
        scale = routed_scaling_factor
    topk_vals = topk_vals * scale
    # Fallback outputs used when no group had a finite score: uniform 1/topk
    # weights and identity indices 0..topk-1.
    default_idx = pos_range.to(tl.int32)
    default_vals = tl.full([TOPK], 1.0 / topk, tl.float32)
    final_vals = tl.where(if_proceed, topk_vals, default_vals)
    final_idx = tl.where(if_proceed, topk_idx, default_idx)
    # Only the first `topk` lanes are real output columns.
    tl.store(
        topk_values_ptr + pid * stride_out_token + pos_range,
        final_vals,
        mask=pos_range < topk,
    )
    tl.store(
        topk_indices_ptr + pid * stride_out_token + pos_range,
        final_idx,
        mask=pos_range < topk,
    )
def grouped_topk(
    scores: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Group-limited top-k expert routing.

    For each token, experts are partitioned into ``n_group`` contiguous
    groups.  The best ``topk_group`` groups are selected by the sum of each
    group's top-2 biased scores, and the final ``topk`` experts are chosen
    (by ``score + bias``) only from those groups.  Returned weights are the
    *unbiased* scores, scaled by ``routed_scaling_factor`` (and renormalized
    to sum to it when ``renormalize`` is true).

    Args:
        scores: ``(num_tokens, num_experts)`` router logits/scores.
        n_group: number of expert groups; must divide ``num_experts``.
        topk_group: number of groups to keep per token.
        topk: number of experts to select per token.
        renormalize: if true, weights are normalized to sum to
            ``routed_scaling_factor``; otherwise just multiplied by it.
        routed_scaling_factor: final scale applied to the weights.
        bias: per-expert bias of length ``num_experts``, added for ranking
            only (not reflected in the returned weights).
        scoring_func: 0 = use ``scores`` as-is, 1 = apply sigmoid first.

    Returns:
        ``(topk_values, topk_indices)`` of shapes ``(num_tokens, topk)`` with
        dtypes ``float32`` and ``int32``.

    Raises:
        ValueError: on malformed shapes/dtypes or inconsistent routing
            configuration.
    """
    logger.debug("GEMS GROUPED TOPK")
    if scores.ndim != 2:
        raise ValueError("scores must be a 2D Tensor")
    num_tokens, num_experts = scores.shape
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 32:
        raise ValueError("n_group should be smaller than or equal to 32")
    if topk > 32:
        raise ValueError("topk should be smaller than or equal to 32 for now")
    # Guard inconsistent routing configs that would otherwise make the
    # kernels silently produce garbage.
    if topk_group > n_group:
        raise ValueError(
            f"topk_group ({topk_group}) must not exceed n_group ({n_group})"
        )
    if topk > num_experts:
        raise ValueError(
            f"topk ({topk}) must not exceed num_experts ({num_experts})"
        )
    if scoring_func not in (0, 1):
        raise ValueError("scoring_func must be 0 (none) or 1 (sigmoid)")

    # Normalize bias to a flat tensor matching the scores dtype.
    if bias.dtype != scores.dtype:
        bias = bias.to(scores.dtype)
    if bias.ndim != 1:
        bias = bias.flatten()
    if len(bias) != num_experts:
        raise ValueError(
            f"bias length ({len(bias)}) must match num_experts ({num_experts})"
        )

    num_experts_per_group = num_experts // n_group

    # Map the torch dtype onto the triton constexpr used for the final cast.
    if scores.dtype == torch.float32:
        INPUT_DTYPE = tl.float32
    elif scores.dtype == torch.float16:
        INPUT_DTYPE = tl.float16
    elif scores.dtype == torch.bfloat16:
        INPUT_DTYPE = tl.bfloat16
    else:
        raise ValueError(f"Unsupported dtype: {scores.dtype}")

    if scoring_func == 1:
        from flag_gems.ops.tanh import tanh as gems_tanh

        # sigmoid(x) == 0.5 * tanh(x / 2) + 0.5 — computed via the in-house
        # tanh kernel.
        scores_processed = 0.5 * gems_tanh(0.5 * scores) + 0.5
    else:
        scores_processed = scores

    group_scores = torch.empty(
        (num_tokens, n_group),
        device=scores.device,
        dtype=scores.dtype,
    )
    topk_values = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.float32,
    )
    topk_indices = torch.empty(
        (num_tokens, topk),
        device=scores.device,
        dtype=torch.int32,
    )

    # Stage 1: one program per (token, group) computes the top-2 group score.
    BLOCK1 = triton.next_power_of_2(num_experts_per_group)
    grid1 = (num_tokens * n_group,)
    topk_with_k2_triton[grid1](
        scores_processed,
        bias,
        group_scores,
        num_experts_per_group,
        n_group,
        scores_processed.stride(0),
        group_scores.stride(0),
        BLOCK_SIZE=BLOCK1,
        INPUT_DTYPE=INPUT_DTYPE,
    )

    # Stage 2: one program per token picks groups then experts.
    BLOCK_GROUP = triton.next_power_of_2(n_group)
    BLOCK_EXPERT = triton.next_power_of_2(num_experts)
    grid2 = (num_tokens,)
    group_idx_and_topk_triton[grid2](
        scores_processed,
        group_scores,
        topk_values,
        topk_indices,
        bias,
        num_tokens,
        n_group,
        topk_group,
        topk,
        num_experts,
        num_experts_per_group,
        routed_scaling_factor,
        scores_processed.stride(0),
        group_scores.stride(0),
        topk_values.stride(0),
        N_GROUP=n_group,
        TOPK_GROUP=topk_group,
        TOPK=topk,
        BLOCK_GROUP=BLOCK_GROUP,
        BLOCK_EXPERT=BLOCK_EXPERT,
        INPUT_DTYPE=INPUT_DTYPE,
        renormalize=int(renormalize),
    )
    return topk_values, topk_indices