Coverage for src/flag_gems/runtime/backend/_tsingmicro/fused/moe_align_block_size.py: 0%
136 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
# The TLE extension is optional: when it cannot be imported the module falls
# back to the kernels that do not reference `tle`.
try:
    import triton.experimental.tle.language.gpu as tle
except ImportError:
    HAS_TLE = False
else:
    HAS_TLE = True


logger = logging.getLogger(__name__)
def ceil_div(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for positive ``b``."""
    quotient, _ = divmod(a + b - 1, b)
    return quotient
def round_up(x: int, y: int) -> int:
    """Return ``((x + y - 1) // y) * y``: ``x`` rounded up to a multiple of ``y`` for positive ``y``."""
    padded = x + y - 1
    # (p // y) * y == p - p % y for Python integers.
    return padded - padded % y
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1_tle(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
):
    """Stage 1 (TLE variant): initialise the outputs and count tokens per expert.

    Each program
      * writes ``numel`` into its ``block_size_sorted``-wide slice of
        ``sorted_token_ids_ptr`` and ``0`` into its ``block_size_expert``-wide
        slice of ``expert_ids_ptr`` (both masked to the buffer lengths);
      * loads up to ``tokens_per_thread`` ids from ``topk_ids_ptr`` starting at
        ``pid * tokens_per_thread``;
      * counts them per expert with atomic adds into a buffer allocated with
        ``scope=tle.smem``;
      * stores the counts at flat offset ``(pid + 1) * num_experts + expert``
        of ``tokens_cnts_ptr``.

    Ids at positions ``>= numel`` or with value ``>= num_experts`` are not
    counted.
    NOTE(review): negative ids are not filtered here - confirm callers never
    pass them.
    """
    prog = tl.program_id(0)

    # Pre-fill this program's slice of the sorted-id buffer with `numel`.
    sorted_idx = tl.arange(0, block_size_sorted) + prog * block_size_sorted
    tl.store(
        sorted_token_ids_ptr + sorted_idx,
        numel,
        mask=sorted_idx < numel_sorted_token_ids,
    )

    # Zero this program's slice of the expert-id buffer.
    expert_idx = tl.arange(0, block_size_expert) + prog * block_size_expert
    tl.store(expert_ids_ptr + expert_idx, 0, mask=expert_idx < numel_expert_ids)

    # Load the ids owned by this program in one vector.
    token_idx = prog * tokens_per_thread + tl.arange(0, tokens_per_thread)
    in_range = token_idx < numel
    ids = tl.load(topk_ids_ptr + token_idx, mask=in_range, other=0).to(tl.int32)
    countable = in_range & (ids < num_experts)
    # Lanes that are not countable point at slot 0 but are masked out below.
    ids = tl.where(countable, ids, 0)

    lanes = tl.arange(0, BLOCK_EXPERT)
    lane_is_expert = lanes < num_experts

    scratch = tle.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.smem,
        nv_mma_shared_layout=False,
    )
    scratch_ptrs = tle.local_ptr(scratch, (lanes,))
    tl.store(scratch_ptrs, 0)
    # Barrier between zeroing the scratch buffer and accumulating into it.
    tl.debug_barrier()

    bucket_ptrs = tle.local_ptr(scratch, (ids,))
    tl.atomic_add(bucket_ptrs, 1, mask=countable, sem="relaxed", scope="cta")
    # Barrier between the atomic updates and reading the totals back.
    tl.debug_barrier()

    per_expert = tl.load(scratch_ptrs, mask=lane_is_expert, other=0)
    out_base = (prog + 1) * num_experts
    tl.store(tokens_cnts_ptr + out_base + lanes, per_expert, mask=lane_is_expert)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1: initialise the outputs and count tokens per expert.

    Each program writes ``numel`` into its slice of ``sorted_token_ids_ptr``
    and ``0`` into its slice of ``expert_ids_ptr`` (masked to the buffer
    lengths), then walks up to ``tokens_per_thread`` ids starting at
    ``pid * tokens_per_thread`` and increments the counter at flat offset
    ``(pid + 1) * num_experts + id`` of ``tokens_cnts_ptr`` for each one.
    """
    prog = tl.program_id(0)

    # Pre-fill this program's slice of the sorted-id buffer with `numel`.
    sorted_idx = tl.arange(0, block_size_sorted) + prog * block_size_sorted
    tl.store(
        sorted_token_ids_ptr + sorted_idx,
        numel,
        mask=sorted_idx < numel_sorted_token_ids,
    )

    # Zero this program's slice of the expert-id buffer.
    expert_idx = tl.arange(0, block_size_expert) + prog * block_size_expert
    tl.store(expert_ids_ptr + expert_idx, 0, mask=expert_idx < numel_expert_ids)

    first_token = prog * tokens_per_thread
    row_base = (prog + 1) * num_experts

    for step in range(tokens_per_thread):
        # The last program may own fewer than tokens_per_thread tokens.
        if first_token + step < numel:
            expert = tl.load(topk_ids_ptr + first_token + step)
            seen = tl.load(tokens_cnts_ptr + row_base + expert)
            tl.store(tokens_cnts_ptr + row_base + expert, seen + 1)
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    experts_per_cta: tl.constexpr,
):
    """Stage 2 (vectorised): in-place cumulative sum down the rows of the count table.

    Each program handles ``experts_per_cta`` consecutive columns starting at
    ``pid * experts_per_cta`` and replaces the entries at rows
    ``1..num_experts`` of those columns with their running sum along the row
    axis. Entries are addressed as ``row * num_experts + col``.
    """
    prog = tl.program_id(0)

    # Row indices 1..num_experts, shape [num_experts, 1].
    rows = tl.arange(0, num_experts)[:, None] + 1
    # Columns owned by this program, shape [1, experts_per_cta].
    cols = prog * experts_per_cta + tl.arange(0, experts_per_cta)[None, :]

    # Flat offsets of the tile, shape [num_experts, experts_per_cta].
    tile_ptrs = tokens_cnts_ptr + (rows * num_experts + cols)

    running = tl.cumsum(tl.load(tile_ptrs), axis=0)
    tl.store(tile_ptrs, running)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar): in-place cumulative sum down one column of the count table.

    Program ``pid`` walks rows ``1..num_experts`` of column ``pid`` (entries at
    ``row * num_experts + pid``) and overwrites each entry with the running
    total of the entries seen so far.
    """
    col = tl.program_id(0)

    running = 0
    for row in range(1, num_experts + 1):
        cell = tokens_cnts_ptr + row * num_experts + col
        running = running + tl.load(cell)
        tl.store(cell, running)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: pad per-expert totals to ``block_size`` and build their prefix sum.

    Reads the ``num_experts`` entries starting at flat offset
    ``num_experts * num_experts`` of ``tokens_cnts_ptr``, rounds each one up to
    a multiple of ``block_size``, writes the inclusive prefix sum to
    ``cumsum_ptr[1 : num_experts + 1]`` and the grand total to
    ``total_tokens_post_pad_ptr``. ``cumsum_ptr[0]`` is not written here.
    """
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    # `other=0` matters: lanes beyond num_experts (present whenever num_experts
    # is not a power of two) take part in the cumsum/sum below, and a masked
    # load without `other` leaves them undefined.
    token_cnts = tl.load(
        tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask, other=0
    )
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: label each block with its expert and scatter token indices.

    Program ``pid`` does two independent things:
      * for every ``block_size`` step in ``[cumsum[pid], cumsum[pid + 1])`` it
        writes ``pid`` to ``expert_ids_ptr[pos // block_size]``;
      * for every token index it owns (``tokens_per_thread`` indices starting
        at ``pid * tokens_per_thread``, clipped to ``numel``) it writes the
        index to ``sorted_token_ids_ptr`` at
        ``tokens_cnts[pid * num_experts + expert] + cumsum[expert]`` and then
        increments that counter.
    """
    prog = tl.program_id(0)

    # Part 1: one expert_ids entry per block in this expert's padded segment.
    seg_begin = tl.load(cumsum_ptr + prog)
    seg_end = tl.load(cumsum_ptr + prog + 1)
    for pos in range(seg_begin, seg_end, block_size):
        tl.store(expert_ids_ptr + pos // block_size, prog)

    # Part 2: place this program's tokens into their expert segments.
    first_token = prog * tokens_per_thread
    row_base = prog * num_experts

    for token in range(first_token, tl.minimum(first_token + tokens_per_thread, numel)):
        expert = tl.load(topk_ids_ptr + token)
        placed = tl.load(tokens_cnts_ptr + row_base + expert)
        dest = placed + tl.load(cumsum_ptr + expert)
        tl.store(sorted_token_ids_ptr + dest, token)
        tl.store(tokens_cnts_ptr + row_base + expert, placed + 1)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Run the four-stage Triton pipeline that fills the caller's output tensors.

    Args:
        topk_ids: expert id selected for each token slot; only ``numel()`` and
            ``device`` are read here, the data is read by the kernels.
        num_experts: number of experts; also the launch grid size of stages
            1 and 4.
        block_size: alignment unit passed to stages 3 and 4.
        sorted_token_ids: output, written in place by stages 1 and 4.
        expert_ids: output, written in place by stages 1 and 4.
        num_tokens_post_pad: output, written in place by stage 3.

    Returns:
        None. All results are written into the three output tensors.
    """
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    device = topk_ids.device

    grid = (num_experts,)

    # Per-program extents feed tl.arange inside the kernels. They are clamped
    # to 1 because triton.next_power_of_2(0) is 0, which would otherwise ask
    # for a zero-width range when an input or output is empty.
    tokens_per_thread = max(1, triton.next_power_of_2(ceil_div(numel, num_experts)))
    block_size_sorted = max(
        1, triton.next_power_of_2(ceil_div(numel_sorted_token_ids, num_experts))
    )
    block_size_expert = max(
        1, triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    )

    # Scratch buffers are zero-initialised: stage 1 accumulates into
    # tokens_cnts and stage 3 never writes cumsum[0].
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    # Stage 1: initialise the outputs and count tokens per (program, expert).
    if HAS_TLE:
        # NOTE(review): BLOCK_EXPERT is a multiple of 32 but not necessarily a
        # power of two (e.g. num_experts=96); confirm the TLE kernel's
        # tl.arange(0, BLOCK_EXPERT) accepts that on this backend.
        block_expert = triton.cdiv(num_experts, 32) * 32
        moe_align_block_size_stage1_tle[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
            BLOCK_EXPERT=block_expert,
        )
    else:
        moe_align_block_size_stage1[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
        )

    # Stage 2: column-wise running sum of the counts. The vectorised kernel
    # splits the columns across 16 programs, so it needs a power-of-two
    # num_experts of at least 16; below that num_experts // 16 is 0 and the
    # grid computation would divide by zero, so the scalar kernel is used.
    if num_experts >= 16 and num_experts == num_experts_next_power_of_2:
        experts_per_cta = num_experts // 16
        grid2 = (num_experts // experts_per_cta,)
        moe_align_block_size_stage2_vec[grid2](
            tokens_cnts,
            num_experts,
            experts_per_cta,
        )
    else:
        moe_align_block_size_stage2[grid](
            tokens_cnts,
            num_experts,
        )

    # Stage 3: pad per-expert totals to block_size and prefix-sum them.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )

    # Stage 4: write expert_ids per block and scatter the token indices.
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the output buffers and run the block-alignment pipeline.

    Args:
        topk_ids: expert id selected for each token slot.
        block_size: alignment unit for each expert's token segment.
        num_experts: number of experts.
        expert_map: optional lookup tensor; when given, the returned expert ids
            are ``expert_map[expert_ids]``.
        pad_sorted_ids: when true, the sorted-id buffer length is rounded up to
            a multiple of ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``, all int32 tensors on
        ``topk_ids.device`` (``expert_ids`` takes the dtype of ``expert_map``
        indexing when a map is supplied).
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")

    device = topk_ids.device

    # Upper bound on the padded length: every expert may need up to
    # block_size - 1 padding slots.
    capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)
    num_blocks = triton.cdiv(capacity, block_size)

    sorted_ids = torch.empty((capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty((num_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad

    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad