Coverage for src/flag_gems/fused/topk_softmax.py: 15%
54 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,
    finished_ptr,  # interface reserved, not yet used
    output_ptr,
    indices_ptr,
    source_rows_ptr,
    num_rows,
    k,
    num_experts,
    start_expert,
    end_expert,
    renormalize: tl.constexpr,
    INDEX_TY: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_EXPERTS: tl.constexpr,
):
    """Fused softmax + iterative top-k selection over MoE gating logits.

    Each program handles BLOCK_SIZE_ROWS token rows. For every row it:
      1. loads the logits for experts in [start_expert, end_expert),
      2. computes a numerically stable softmax along the expert axis,
      3. selects the k largest probabilities one at a time, storing the
         weight, the expert index, and a flattened source-row id
         (ki * num_rows + row) for each pick,
      4. if `renormalize`, rescales the k stored weights so they sum to 1.

    Args:
        input_ptr: gating logits, laid out row-major as (num_rows, num_experts).
        finished_ptr: unused; kept for interface compatibility.
        output_ptr: top-k weights output, shape (num_rows, k).
        indices_ptr: top-k expert indices output, stored as INDEX_TY.
        source_rows_ptr: int32 output of flattened source-row ids.
        start_expert / end_expert: half-open expert slice this launch covers.
    """
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_SIZE_ROWS) + pid * BLOCK_SIZE_ROWS
    valid_rows = rows < num_rows

    # Columns cover the expert slice [start_expert, end_expert).
    cols = start_expert + tl.arange(0, BLOCK_SIZE_EXPERTS)
    valid_cols = cols < end_expert

    # Out-of-range entries load as -inf so they never win the max/softmax.
    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    ).to(tl.float32)

    # Numerically stable softmax along the expert axis; the 1e-8 guards
    # against division by zero on fully-masked rows.
    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)

    selected_sum = tl.zeros([BLOCK_SIZE_ROWS], dtype=tl.float32)
    for ki in range(k):
        # curr_arg is the argmax position within the block's expert axis
        # (0 .. BLOCK_SIZE_EXPERTS-1), i.e. local to the slice, so the
        # selected global column satisfies cols == curr_arg + start_expert.
        curr_max, curr_arg = tl.max(probs, axis=1, return_indices=True)

        tl.store(output_ptr + rows * k + ki, curr_max, mask=valid_rows)
        # NOTE(review): this stores the slice-local index; if callers with
        # start_expert != 0 expect a global expert id, it should be
        # (curr_arg + start_expert) — confirm against expert-parallel callers.
        tl.store(indices_ptr + rows * k + ki, curr_arg.to(INDEX_TY), mask=valid_rows)
        tl.store(
            source_rows_ptr + rows * k + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )
        if renormalize:
            selected_sum += curr_max

        # Knock out the expert just selected so the next iteration picks the
        # runner-up. BUG FIX: the original compared
        # cols == curr_arg - start_expert, which is only correct when
        # start_expert == 0; the selected column is cols == curr_arg + start_expert.
        probs = tl.where(
            cols[None, :] == (curr_arg[:, None] + start_expert), -float("inf"), probs
        )

    if renormalize:
        # Rescale the k stored weights so each row's selection sums to 1.
        norm = selected_sum + 1e-8
        for ki in range(k):
            idx = rows * k + ki
            val = tl.load(output_ptr + idx, mask=valid_rows)
            tl.store(output_ptr + idx, val / norm, mask=valid_rows)
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Run the fused top-k gating softmax kernel over `gating_output`.

    Writes results in place: `topk_weights` receives the selected softmax
    probabilities, `topk_indices` the chosen expert indices, and
    `token_expert_indices` the flattened source-row ids. Returns None.

    Args:
        topk_weights: (num_tokens, k) float output buffer; k must be <= 32.
        topk_indices: (num_tokens, k) int32/uint32/int64 output buffer.
        token_expert_indices: (num_tokens, k) int32 output buffer.
        gating_output: (num_tokens, num_experts) gating logits.
        renormalize: if True, the k weights per token are rescaled to sum to 1.

    Raises:
        TypeError: if `topk_indices` has an unsupported dtype.
    """
    logger.debug("GEMS TOPK SOFTMAX")
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    assert topk <= 32

    # Map the torch index dtype onto its Triton counterpart.
    dtype_to_tl = {
        torch.int32: tl.int32,
        torch.uint32: tl.uint32,
        torch.int64: tl.int64,
    }
    index_ty = dtype_to_tl.get(topk_indices.dtype)
    if index_ty is None:
        raise TypeError("topk_indices must be int32/int64/uint32")

    # Round the expert dimension up to a multiple of 32 (warp width), cap at
    # 1024, and spend the remaining threads on rows.
    max_total_threads = 1024
    experts_block = min(((triton.next_power_of_2(num_experts) + 31) // 32) * 32, 1024)
    rows_block = max(max_total_threads // experts_block, 1)

    # If num_experts > 128, intra-warp shuffling is forced for reduction,
    # which requires the warp layout to be confined to a single row.
    # Consequently, in the TTGIR, the second dimension of warpsPerCTA is fixed to 1.
    if num_experts > 128:
        rows_block, num_warps = 1, 1
    else:
        num_warps = 4

    grid = (triton.cdiv(num_tokens, rows_block),)
    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        renormalize=renormalize,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=rows_block,
        BLOCK_SIZE_EXPERTS=experts_block,
        num_warps=num_warps,
    )