Coverage for src/flag_gems/fused/topk_softmax.py: 12%
51 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,  # [num_rows, num_experts] gating logits
    finished_ptr,  # interface reserved, not yet used
    output_ptr,  # [num_rows, k] out: selected softmax probabilities
    indices_ptr,  # [num_rows, k] out: selected expert indices (INDEX_TY)
    source_rows_ptr,  # [num_rows, k] out: flattened ids `ki * num_rows + row` (int32)
    num_rows,
    k,
    num_experts,
    start_expert,  # first expert column handled by this launch (caller passes 0)
    end_expert,  # one-past-last expert column (caller passes num_experts)
    renormalize: tl.constexpr,
    INDEX_TY: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_EXPERTS: tl.constexpr,
):
    """Fused softmax + top-k selection over the expert dimension.

    Each program handles BLOCK_SIZE_ROWS rows. For every row it computes a
    numerically stable softmax over the experts in [start_expert, end_expert),
    then iteratively extracts the k largest probabilities, storing for each
    selection the probability, its expert index, and a flattened source-row
    id. If `renormalize` is set, the k stored probabilities are rescaled so
    they sum to 1 per row.
    """
    pid = tl.program_id(0)
    # Rows covered by this program; mask off the tail of the last block.
    rows = tl.arange(0, BLOCK_SIZE_ROWS) + pid * BLOCK_SIZE_ROWS
    valid_rows = rows < num_rows
    # Global expert column ids covered by this block, masked at end_expert.
    cols = start_expert + tl.arange(0, BLOCK_SIZE_EXPERTS)
    valid_cols = cols < end_expert
    # Masked positions load -inf so they contribute exp(-inf)=0 to the softmax
    # denominator and can never win the argmax below.
    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    ).to(tl.float32)
    # Numerically stable softmax; the epsilon guards a fully-masked row
    # whose denominator would otherwise be 0.
    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)
    # Iteratively peel off the k largest probabilities per row.
    selected_sum = tl.zeros([BLOCK_SIZE_ROWS], dtype=tl.float32)
    for ki in range(k):
        # curr_arg is the argmax position WITHIN the block, i.e. a local
        # column offset in [0, BLOCK_SIZE_EXPERTS).
        curr_max, curr_arg = tl.max(probs, axis=1, return_indices=True)
        tl.store(output_ptr + rows * k + ki, curr_max, mask=valid_rows)
        # NOTE(review): the stored index is the local argmax offset; it equals
        # the global expert id only when start_expert == 0 (the sole caller).
        # Supporting start_expert != 0 would need `curr_arg + start_expert`.
        tl.store(indices_ptr + rows * k + ki, curr_arg.to(INDEX_TY), mask=valid_rows)
        tl.store(
            source_rows_ptr + rows * k + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )
        if renormalize:
            selected_sum += curr_max
        # Knock out this round's winner so the next iteration selects the
        # runner-up.
        # NOTE(review): since cols = start_expert + arange, this condition is
        # `arange == curr_arg - 2*start_expert` and matches the selected
        # position only when start_expert == 0 — confirm before enabling
        # expert-parallel slicing (expected form: curr_arg + start_expert).
        probs = tl.where(
            cols[None, :] == (curr_arg[:, None] - start_expert), -float("inf"), probs
        )
    if renormalize:
        # Second pass: rescale the k stored weights so they sum to 1 per row;
        # epsilon avoids division by zero when all selections were 0.
        norm = selected_sum + 1e-8
        for ki in range(k):
            idx = rows * k + ki
            val = tl.load(output_ptr + idx, mask=valid_rows)
            tl.store(output_ptr + idx, val / norm, mask=valid_rows)
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Fused softmax + top-k over expert gating logits (in-place outputs).

    Args:
        topk_weights: [num_tokens, topk] output buffer for the selected
            softmax probabilities.
        topk_indices: [num_tokens, topk] output buffer for the selected
            expert indices; dtype must be int32, uint32, or int64.
        token_expert_indices: [num_tokens, topk] output buffer for flattened
            source-row ids (``ki * num_tokens + token``), written as int32.
        gating_output: [num_tokens, num_experts] gating logits.
        renormalize: if True, rescale each token's top-k weights to sum to 1.

    Raises:
        ValueError: if topk > 32 or num_experts > 1024 (the kernel loads the
            expert dimension as a single block capped at 1024 columns, so
            larger values would silently ignore the remaining experts).
        TypeError: if topk_indices has an unsupported dtype.
    """
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    # Validate explicitly rather than with `assert` (stripped under `-O`).
    if topk > 32:
        raise ValueError(f"topk must be <= 32, got {topk}")
    # BLOCK_SIZE_EXPERTS is capped at 1024 below and the kernel covers the
    # expert dimension with one block, so larger inputs would be truncated.
    if num_experts > 1024:
        raise ValueError(f"num_experts must be <= 1024, got {num_experts}")

    index_type_map = {
        torch.int32: tl.int32,
        torch.uint32: tl.uint32,
        torch.int64: tl.int64,
    }
    try:
        index_ty = index_type_map[topk_indices.dtype]
    except KeyError:
        raise TypeError("topk_indices must be int32/int64/uint32") from None

    max_total_threads = 1024
    # Round the expert block up to a multiple of 32 (warp width), cap at 1024.
    BLOCK_SIZE_EXPERTS = ((triton.next_power_of_2(num_experts) + 31) // 32) * 32
    BLOCK_SIZE_EXPERTS = min(BLOCK_SIZE_EXPERTS, 1024)
    # Spend the remaining threads on rows; always keep at least one row.
    BLOCK_SIZE_ROWS = max(max_total_threads // BLOCK_SIZE_EXPERTS, 1)

    # If num_experts > 128, intra-warp shuffling is forced for reduction,
    # which requires the warp layout to be confined to a single row.
    # Consequently, in the TTGIR, the second dimension of warpsPerCTA is fixed to 1.
    if num_experts > 128:
        BLOCK_SIZE_ROWS = 1
        num_warps = 1
    else:
        num_warps = 4

    grid = (triton.cdiv(num_tokens, BLOCK_SIZE_ROWS),)
    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        renormalize=renormalize,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
        BLOCK_SIZE_EXPERTS=BLOCK_SIZE_EXPERTS,
        num_warps=num_warps,
    )