Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/moe_align_block_size.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
# Child logger under the shared "flag_gems" root; lstrip(".") guards against
# a leading dot in relative module names.
_root_logger = logging.getLogger("flag_gems")
logger = _root_logger.getChild(__name__.lstrip("."))
def ceil_div(a, b):
    """Return a divided by b, rounded up to the nearest integer."""
    biased = a + b - 1
    return biased // b
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``."""
    num_chunks = (x + y - 1) // y
    return num_chunks * y
@triton.jit(do_not_specialize=["numel", "tokens_per_thread"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread,
):
    """Count, per program, how many of its tokens go to each expert.

    Program ``pid`` scans its contiguous slice of ``topk_ids`` (up to
    ``tokens_per_thread`` entries) and accumulates per-expert counts into row
    ``pid + 1`` of the ``tokens_cnts`` buffer; row 0 is never written here so
    later stages can build a prefix sum over rows.  Assumes the buffer has
    shape (num_programs + 1, num_experts) — see the host launcher.
    """
    pid = tl.program_id(0)

    # First flat token index owned by this program.
    start_idx = pid * tokens_per_thread

    # Base offset of this program's private counter row (row pid + 1).
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            # Non-atomic read-modify-write: safe because each program writes
            # only its own row.
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """In-place inclusive prefix sum down one expert's column of counts.

    Program ``pid`` walks column ``pid`` of ``tokens_cnts`` over rows
    1..num_experts, replacing each entry with the running total, so row i ends
    up holding the number of tokens routed to expert ``pid`` by the first i
    stage-1 programs.  Row 0 (all zeros) is left untouched.
    """
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    """Single-program pass building block-aligned cumulative expert offsets.

    Reads the last row of ``tokens_cnts`` (which after stage 2 holds each
    expert's total token count), pads every count up to a multiple of
    ``block_size``, and writes the running totals to cumsum[1..num_experts]
    (cumsum[0] stays 0, pre-zeroed on the host).  The final padded total is
    also stored to ``total_tokens_post_pad_ptr``.
    """
    last_cumsum = 0
    # Flat offset of the last row (row index num_experts, stride num_experts).
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # Pad each expert's segment to a whole number of blocks.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit(do_not_specialize=["numel", "tokens_per_thread"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread,
):
    """Scatter token indices into their padded, expert-sorted positions.

    Program ``pid`` (one per expert) first labels every block in the span
    [cumsum[pid], cumsum[pid+1]) with expert id ``pid``, then re-walks the
    same token slice as stage 1 and writes each token's flat index into
    ``sorted_token_ids`` at cumsum[expert] + rank-within-expert.  Row ``pid``
    of ``tokens_cnts`` — which after stage 2 holds the rank offset contributed
    by earlier programs' slices — is advanced in place as ranks are consumed.
    """
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # One expert-id entry per block of block_size slots in this expert's span.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Reuse stage 1's partitioning of the flattened topk_ids.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        # Bump the consumed rank; non-atomic is safe — row pid is private.
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Run the four-stage MoE block-alignment pipeline, filling the three
    output tensors in place."""
    numel = topk_ids.numel()
    device = topk_ids.device
    # Scratch buffers: (num_experts + 1) x num_experts per-program counters
    # (row 0 stays zero) and the padded cumulative offset per expert.
    counts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )
    offsets = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)
    tokens_per_thread = ceil_div(numel, num_experts)
    launch_grid = (num_experts,)

    # Stage 1: per-program histograms of expert assignments.
    moe_align_block_size_stage1[launch_grid](
        topk_ids,
        counts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    # Stage 2: prefix-sum each expert column across programs.
    moe_align_block_size_stage2[launch_grid](
        counts,
        num_experts,
    )
    # Stage 3: block-pad the totals into cumulative expert offsets.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        counts,
        offsets,
        num_experts,
        block_size,
    )
    # Stage 4: scatter token indices and label blocks with expert ids.
    moe_align_block_size_stage4[launch_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        counts,
        offsets,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Group token indices by expert and pad each expert's segment to a
    multiple of ``block_size``.

    Args:
        topk_ids: tensor of expert indices chosen per token; its flattened
            length is the token count processed here.
        block_size: alignment granularity for each expert's segment.
        num_experts: total number of experts.
        expert_map: optional lookup tensor; when given, block expert ids are
            translated through it (global -> local ids).
        pad_sorted_ids: when True, also round the sorted-ids buffer length
            up to a multiple of ``block_size``.

    Returns:
        Tuple ``(sorted_ids, expert_ids, num_tokens_post_pad)`` where
        ``sorted_ids`` holds flattened token indices grouped by expert with
        padding slots filled by the sentinel ``topk_ids.numel()``,
        ``expert_ids`` gives one expert id per block, and
        ``num_tokens_post_pad`` is the 1-element padded total.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    # Worst case: every expert needs (block_size - 1) padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # Sentinel value marking unused (padding) slots.
    sorted_ids.fill_(topk_ids.numel())
    # Use the module helper instead of triton.cdiv for consistency with the
    # rest of this file's host-side code.
    max_num_m_blocks = ceil_div(max_num_tokens_padded, block_size)
    expert_ids = torch.zeros(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    # Fix: shape must be the 1-tuple (1,), not the bare int (1).
    num_tokens_post_pad = torch.empty(
        (1,), dtype=torch.int32, device=topk_ids.device
    )

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad