# Coverage for src/flag_gems/fused/moe_align_block_size.py: 34% (131 statements)

import logging
from typing import Optional

import torch
import triton
import triton.language as tl

try:
    import triton.experimental.tle.language.gpu as tle

    HAS_TLE = True
except ImportError:
    HAS_TLE = False

logger = logging.getLogger(__name__)
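

# ---------------------------------------------------------------------------
# Overview (a summary of the stages below): the kernels group the flattened
# top-k expert assignments in `topk_ids` by expert and pad each expert's
# token list to a multiple of `block_size`. Stage 1 builds per-program
# expert histograms, stage 2 turns them into running prefix sums, stage 3
# computes block-aligned cumulative offsets per expert, and stage 4 scatters
# the token indices into their sorted, padded positions.
# ---------------------------------------------------------------------------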


def ceil_div(a, b):
    return (a + b - 1) // b


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
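

# Worked example for the two helpers (values are illustrative):
#   ceil_div(10, 4) == 3    # smallest integer >= 10 / 4
#   round_up(10, 4) == 12   # smallest multiple of 4 >= 10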


@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1_tle(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
):
    pid = tl.program_id(0)

    # Pre-fill the outputs: sorted_token_ids gets the sentinel `numel`
    # (an out-of-range token index) and expert_ids gets 0.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0).to(tl.int32)
    valid = mask & (expert_id < num_experts)
    expert_id = tl.where(valid, expert_id, 0)

    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Count this program's tokens per expert in shared memory, then flush
    # row pid + 1 of tokens_cnts with a single global write.
    smem_counts = tle.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.smem,
        nv_mma_shared_layout=False,
    )
    smem_ptrs = tle.local_ptr(smem_counts, (expert_offsets,))
    tl.store(smem_ptrs, 0)
    tl.debug_barrier()

    count_ptrs = tle.local_ptr(smem_counts, (expert_id,))
    tl.atomic_add(count_ptrs, 1, mask=valid, sem="relaxed", scope="cta")
    tl.debug_barrier()

    counts = tl.load(smem_ptrs, mask=expert_mask, other=0)
    tl.store(tokens_cnts_ptr + off_c + expert_offsets, counts, mask=expert_mask)
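

# After stage 1, tokens_cnts has shape (num_experts + 1, num_experts):
# row 0 stays zero and row pid + 1 holds the per-expert token counts of
# program pid's slice of topk_ids. The TLE variant above accumulates the
# counts with CTA-scoped atomics in shared memory before one global write;
# the fallback that follows issues the atomics directly to global memory.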


@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)


@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
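

# Both stage-2 variants replace rows 1..num_experts of tokens_cnts with a
# running prefix sum down each expert column, so row i ends up holding the
# number of tokens that programs 0..i-1 assigned to each expert. The
# vectorized variant uses tl.arange(0, num_experts), which requires
# num_experts to be a power of two; the host dispatches on that condition.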


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
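

# Illustrative stage-3 numbers: with block_size = 16 and final per-expert
# counts [5, 17, 0, 9], the block-aligned counts are [16, 32, 0, 16], so
# cumsum becomes [0, 16, 48, 48, 64] and total_tokens_post_pad is 64.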


@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)
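

# In stage 4, program pid plays two roles: it writes pid into expert_ids
# for every block in expert pid's padded range [cumsum[pid], cumsum[pid+1]),
# and it scatters its own token slice. The scatter rank combines the
# expert's base offset from cumsum with a running within-expert counter:
# row pid of tokens_cnts already holds the stage-2 prefix for programs
# 0..pid-1, so the atomic_add returns a globally unique rank per token.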


def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    # The output tensors must be pre-filled (padded) before the token ranks
    # are computed, to prevent out-of-bounds address accesses.

    grid = (num_experts,)
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))

    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))

    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)
    block_expert = triton.cdiv(num_experts, 32) * 32

    if HAS_TLE:
        moe_align_block_size_stage1_tle[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
            BLOCK_EXPERT=block_expert,
        )
    else:
        moe_align_block_size_stage1[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
        )
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](
            tokens_cnts,
            num_experts,
        )
    else:
        moe_align_block_size_stage2[grid](
            tokens_cnts,
            num_experts,
        )
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
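

# End-to-end example (illustrative, computed by hand): num_experts = 2,
# block_size = 4, and flattened topk_ids = [0, 1, 0, 0] give per-expert
# counts [3, 1], aligned counts [4, 4], cumsum [0, 4, 8], expert_ids =
# [0, 1], and sorted_token_ids = [0, 2, 3, 4, 1, 4, 4, 4], where the value
# 4 (= numel) marks padding slots.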


def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
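

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the module API):
# assumes a CUDA device with Triton available; shapes and dtypes follow the
# conventions above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        num_experts, block_size = 8, 16
        # 64 tokens with top-2 routing; int32 expert ids as the kernels expect
        topk_ids = torch.randint(
            0, num_experts, (64, 2), dtype=torch.int32, device="cuda"
        )
        sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
            topk_ids, block_size, num_experts
        )
        # sorted_ids holds flat token indices grouped by expert and padded to
        # multiples of block_size; padding slots contain topk_ids.numel().
        print(sorted_ids.shape, expert_ids.shape, num_tokens_post_pad.item())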