Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/topk_softmax.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,
    finished_ptr,  # interface reserved, not yet used
    output_ptr,
    indices_ptr,
    source_rows_ptr,
    num_rows,
    k,
    num_experts,
    start_expert,
    end_expert,
    INDEX_TY: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_EXPERTS: tl.constexpr,
):
    # Fused softmax + iterative top-k over the expert axis.
    #
    # Each program handles BLOCK_SIZE_ROWS consecutive rows. For every row it
    # loads columns [start_expert, end_expert) from a buffer addressed as
    # `row * num_experts + col` (i.e. row-major with row stride num_experts),
    # applies a softmax over those columns, then extracts the k largest
    # probabilities one at a time. For row r and rank ki it writes, at flat
    # offset r * k + ki:
    #   output_ptr      <- ki-th largest probability
    #   indices_ptr     <- its global column id (includes start_expert), as INDEX_TY
    #   source_rows_ptr <- ki * num_rows + r, as int32
    # Only the first BLOCK_SIZE_EXPERTS columns after start_expert are
    # visited, so the launcher must pick BLOCK_SIZE_EXPERTS >= the column count.
    pid = tl.program_id(0)
    rows = pid * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    valid_rows = rows < num_rows

    cols = start_expert + tl.arange(0, BLOCK_SIZE_EXPERTS)
    valid_cols = cols < end_expert

    # Out-of-range entries read as -inf so they contribute exp(-inf) == 0 to
    # the softmax. Upcast to fp32 because tl.exp only accepts fp32/fp64 and
    # reduced-precision router logits would otherwise fail to compile.
    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    ).to(tl.float32)

    # Numerically stable softmax: subtract the row max before exponentiating.
    # Fully masked (out-of-range) rows produce NaNs here, but every store
    # below is masked by valid_rows, so those values are never written.
    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)

    out_base = rows * k  # loop-invariant start of each row's k output slots
    for ki in range(k):
        curr_max = tl.max(probs, axis=1)
        # argmax is relative to the tile; shift it to a global column id.
        curr_arg = tl.argmax(probs, axis=1) + start_expert

        tl.store(output_ptr + out_base + ki, curr_max, mask=valid_rows)
        tl.store(
            indices_ptr + out_base + ki,
            curr_arg.to(INDEX_TY),
            mask=valid_rows,
        )
        tl.store(
            source_rows_ptr + out_base + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )

        # Remove the selected expert so the next iteration finds the next
        # largest one. Both `cols` and `curr_arg` are global column ids, so
        # this stays correct for any start_expert (comparing `cols` against a
        # tile-local index would only work when start_expert == 0).
        probs = tl.where(cols[None, :] == curr_arg[:, None], -float("inf"), probs)
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Compute softmax top-k routing and write the results into the given buffers.

    Applies a softmax over the expert dimension of ``gating_output``. For each
    token it then writes the ``topk`` largest probabilities, their expert ids,
    and the source-row index ``ki * num_tokens + token``. Outputs are written
    at flat offset ``token * topk + ki``.

    Args:
        topk_weights: Output probabilities. Its last dimension defines ``topk``.
            Presumably shaped (num_tokens, topk) and contiguous. TODO confirm
            with callers; the kernel writes with a fixed stride of ``topk``.
        topk_indices: Output expert ids. Must be ``torch.int32`` or
            ``torch.int64``.
        token_expert_indices: Output source-row indices, stored as int32.
        gating_output: 2-D router logits of shape (num_tokens, num_experts).

    Raises:
        AssertionError: If ``topk > 32``.
        TypeError: If ``topk_indices`` has an unsupported dtype.
        ValueError: If ``num_experts`` exceeds the single-tile limit (1024).
    """
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    assert topk <= 32

    if topk_indices.dtype == torch.int32:
        index_ty = tl.int32
    elif topk_indices.dtype == torch.int64:
        index_ty = tl.int64
    else:
        # NOTE: uint32 output is not supported yet.
        raise TypeError(
            f"topk_indices must be int32/int64, got {topk_indices.dtype}"
        )

    max_total_threads = 1024
    max_block_experts = 1024
    # The kernel keeps a whole row of experts in a single tile. Clamping the
    # tile would silently ignore experts beyond the limit and return wrong
    # top-k results, so reject such inputs explicitly instead.
    if num_experts > max_block_experts:
        raise ValueError(
            f"topk_softmax supports at most {max_block_experts} experts, "
            f"got {num_experts}"
        )

    # Nothing to compute; skip a zero-sized kernel launch.
    if num_tokens == 0 or topk == 0:
        return

    # The kernel addresses the logits as `row * num_experts + col`, so it
    # needs a row-major contiguous buffer. This is a no-op if already contiguous.
    gating_output = gating_output.contiguous()

    # Power-of-two tile width over experts, at least 32 columns.
    block_size_experts = max(triton.next_power_of_2(num_experts), 32)
    # Fill roughly max_total_threads elements per program with whole rows.
    block_size_rows = max(max_total_threads // block_size_experts, 1)

    grid = (triton.cdiv(num_tokens, block_size_rows),)

    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=block_size_rows,
        BLOCK_SIZE_EXPERTS=block_size_experts,
        isCloseCoreTiling=True,  # backend-specific launch option (kunlunxin)
    )