Coverage for src/flag_gems/fused/moe_align_block_size.py: 24%
328 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2from functools import lru_cache
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils import libentry, libtuner
# Optional dependency: TLE (Triton experimental language extensions) enables
# the fused single-kernel paths below. Without it, the host dispatcher falls
# back to the multi-stage plain-Triton implementation.
try:
    import triton.experimental.tle.language as tle

    HAS_TLE = True
except ImportError:
    tle = None
    HAS_TLE = False

logger = logging.getLogger(__name__)

# Number of CTAs in the thread-block cluster used by the cluster-fused kernel.
TLE_CLUSTER_SIZE = 8
# Token-count cutoff: the single-cluster fused kernel is only attempted for
# workloads with fewer tokens than this.
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Process-wide flag: set once a default Triton allocator has been installed.
_TRITON_ALLOCATOR_INSTALLED = False
# Autotune space for the cooperative atomic-fused kernel: only num_warps is
# tuned (BLOCK_TOKENS is chosen heuristically on the host).
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=4),
    triton.Config(kwargs={}, num_warps=8),
]
# Autotune space for the cluster-fused kernel: BLOCK_TOKENS x num_warps grid.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=8),
]
def ceil_div(a, b):
    """Integer ceiling division: smallest integer >= a / b (assumes b > 0)."""
    # Adding b - 1 before floor division rounds any nonzero remainder upward.
    biased = a + b - 1
    return biased // b
def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y (assumes y > 0)."""
    num_chunks = (x + y - 1) // y
    return num_chunks * y
@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    # Cached 1-D TLE device mesh spanning `num_blocks` CTAs; cached so repeat
    # launches with the same block count reuse the same mesh object.
    return tle.device_mesh({"block": [("block_x", int(num_blocks))]})
@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    # Cached mesh over one thread-block cluster of TLE_CLUSTER_SIZE (8) CTAs,
    # used by the cluster-fused kernel.
    return tle.device_mesh({"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]})
59def _supports_tle_cluster_remote() -> bool:
60 if not torch.cuda.is_available():
61 return False
62 major, _minor = torch.cuda.get_device_capability()
63 return major >= 9
def _install_triton_default_allocator(device: torch.device) -> None:
    """Install a process-wide Triton allocator backed by torch, at most once.

    Some TLE launches need a host-side allocator; this registers one that
    hands Triton raw uint8 buffers from the torch caching allocator.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if _TRITON_ALLOCATOR_INSTALLED:
        return

    def _byte_buffer_alloc(size: int, _alignment: int, _stream: Optional[int]):
        # Alignment/stream hints are ignored; torch.empty satisfies Triton's
        # buffer contract here.
        return torch.empty((size,), dtype=torch.uint8, device=device)

    triton.set_allocator(_byte_buffer_alloc)
    _TRITON_ALLOCATOR_INSTALLED = True
78def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
79 if num_experts >= 256:
80 if numel >= 32768:
81 return 4096, 4
82 if numel >= 1024:
83 return 1024, 4
84 return 256, 8
86 if numel <= 512:
87 return 128, 8
88 if num_experts <= 64 and numel <= 2048:
89 return 128, 8
90 return 256, 8
93def _pick_tle_atomic_fused_launch_params(
94 numel: int, num_experts: int
95) -> "tuple[int, int]":
96 if num_experts >= 256:
97 if numel <= 16384:
98 return 256, 8
99 if numel <= 32768:
100 return 512, 4
101 return 1024, 4
102 return _pick_tle_fused_launch_params(numel, num_experts)
105def _pick_tle_atomic_fused_num_blocks(
106 numel: int, num_experts: int, block_tokens: int, device: torch.device
107) -> int:
108 if device.type != "cuda" or not torch.cuda.is_available():
109 return 1
110 props = torch.cuda.get_device_properties(device)
111 sm_count = int(getattr(props, "multi_processor_count", 1))
112 token_programs = triton.cdiv(numel, block_tokens)
113 cap_mult = 4 if num_experts < 256 else 16
114 block_cap = sm_count * cap_mult
115 return max(1, min(token_programs, block_cap))
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    """Single cooperative-grid kernel fusing every stage of moe_align_block_size.

    Launched with NUM_BLOCKS programs on `mesh`; stages are separated by
    `tle.distributed_barrier(mesh)`:
      0. fill sorted_token_ids with the sentinel `numel` and zero expert_ids;
      1. each CTA builds a per-expert histogram of its strided token slice in
         shared memory, then merges it into the global `cumsum_ptr` with an
         atomic add — the returned old value is this CTA's prefix per expert;
      2. CTA 0 block-aligns the totals and rewrites `cumsum_ptr` as exclusive
         per-expert start offsets, storing the padded total;
      3. each CTA writes expert ids for its strided subset of experts;
      4. each CTA scatters its tokens to start offset + CTA prefix + local rank.
    """
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Stage 0: sentinel-fill sorted ids and zero expert ids, grid-strided.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Per-CTA expert histogram in shared memory.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Stage 1a: count this CTA's tokens per expert (CTA-scope smem atomics).
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Stage 1b: merge into the global totals. The atomic add returns the
    # running total before this CTA's contribution, i.e. this CTA's prefix
    # inside each expert's token list; stash it back into local_counts.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Stage 2 (CTA 0 only): block-align per-expert totals, convert to an
    # exclusive prefix of start offsets, publish the padded token total.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Cache the expert start offsets in shared memory for the scatter phase.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Stage 3: this CTA owns experts pid, pid + NUM_BLOCKS, ...; for each it
    # writes the expert id of every block-size chunk in the expert's padded
    # [start, end) range. The last expert's end is the padded total.
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        end_idx = tl.where(has_next, end_idx, total_tokens)
        # Invalid experts collapse to an empty [0, 0) range (loop does nothing).
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Stage 4: scatter. local_counts now holds this CTA's per-expert prefix,
    # so the atomic add yields prefix + local rank within the expert.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    """Fused moe_align_block_size on a single thread-block cluster.

    Same staging as the cooperative kernel above, but all cross-CTA
    reductions go through rank 0's shared-memory buffer via `tle.remote`
    (cluster remote access), so no global scratch tensor is needed.
    Requires hardware cluster support (gated by the host on SM90+).
    """
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Stage 0: sentinel-fill sorted ids and zero expert ids, strided by rank.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Shared-memory buffers: this rank's histogram, plus a cumsum scratch
    # (every rank allocates one; only rank 0's is written to remotely).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Stage 1a: per-rank expert histogram over its strided token slice.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Stage 1b: merge into rank 0's buffer through the cluster remote window.
    # The returned old value is this rank's prefix within each expert's list;
    # stash it back into local_counts for the scatter phase.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Stage 2 (rank 0 only): block-align totals, rewrite as exclusive start
    # offsets, publish the padded token total.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every rank copies the final start offsets from rank 0 into its own smem.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Stage 3: each rank writes expert ids for its contiguous expert shard of
    # EXPERTS_PER_SHARD experts; the last expert's end is the padded total.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        # Out-of-range experts collapse to an empty [0, 0) range.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Stage 4: scatter each token to start offset + rank prefix + local rank.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1 of the fallback path: initialize outputs and histogram tokens.

    Launched with one program per expert (grid = num_experts). Each program
    fills its slice of sorted_token_ids with the sentinel `numel`, zeroes its
    slice of expert_ids, then counts how many tokens of its contiguous
    `tokens_per_thread` slice route to each expert, accumulating into row
    `pid + 1` of tokens_cnts (row 0 stays zero as the prefix-sum base).
    tokens_per_thread and the block sizes must be powers of two (tl.arange).
    """
    pid = tl.program_id(0)

    # Fill this program's slice of sorted_token_ids with the sentinel value.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Zero this program's slice of expert_ids.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread
    # Row pid + 1 of the (num_experts + 1) x num_experts counts matrix.
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (vectorized): column-wise inclusive prefix sum of tokens_cnts.

    One program per expert column `pid`; rows 1..num_experts are replaced by
    their running sum down the column, so row i becomes the token count for
    expert `pid` accumulated over programs 1..i. Requires num_experts to be
    a power of two (tl.arange constraint — the host checks this).
    """
    pid = tl.program_id(0)

    # Rows 1..num_experts of column `pid` (row 0 is the all-zero base).
    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar): same column-wise prefix sum as the _vec variant,
    done with a sequential loop so it works for any num_experts (used by the
    host when num_experts is not a power of two)."""
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3 (single program): block-align per-expert totals and prefix-sum.

    Reads the final row of tokens_cnts (the per-expert totals after stage 2),
    rounds each up to a multiple of block_size, and stores the inclusive
    cumsum into cumsum_ptr[1..num_experts] (cumsum_ptr[0] stays 0, making it
    an exclusive prefix). Also writes the padded token total.
    """
    # Offset of the last row of the (num_experts + 1) x num_experts matrix.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: write per-block expert ids and scatter tokens.

    One program per expert (grid = num_experts). Program `pid` first labels
    every block in expert pid's padded [cumsum[pid], cumsum[pid+1]) range,
    then scatters its token slice: row `pid` of tokens_cnts holds the
    exclusive per-expert prefix from earlier slices (stage 2's cumsum), and
    the atomic add on it yields each token's rank within its expert.
    """
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # Label every block-size chunk of this expert's padded range.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Scatter the same contiguous token slice this program counted in stage 1.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the three pre-allocated outputs, choosing the fastest kernel path.

    Dispatch order: (1) single-cluster TLE kernel for small CUDA workloads on
    supported hardware, (2) cooperative atomic-fused TLE kernel with retry /
    shrink on launch failure, (3) the four-stage plain-Triton fallback.
    All paths write the same results into sorted_token_ids, expert_ids and
    num_tokens_post_pad in place.
    """
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    # Per-program slice sizes for the staged fallback; rounded to powers of
    # two because the kernels index with tl.arange.
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    # TLE fused paths need the whole expert range in one tile; cap at 1024.
    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            # Returns True on a successful launch; False to fall through to
            # the staged path. Shrinks the grid on cooperative-launch failures.
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg:
                        # Triton needs a host allocator; install one and retry.
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative-grid capacity failure: halve and retry.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        # Small workloads on cluster-capable hardware: try the single-cluster
        # kernel first (BLOCK_TOKENS is supplied by the libtuner configs).
        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # Staged fallback path.
    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    # Row 0 is the prefix-sum base; rows 1..num_experts are per-program counts.
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorized cumsum requires num_experts to be a power of two.
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Group routed tokens by expert into block_size-aligned blocks.

    Args:
        topk_ids: Tensor of routed expert indices (one row of top-k experts
            per token).
        block_size: Granularity each expert's token count is padded up to.
        num_experts: Total number of experts.
        expert_map: Optional lookup tensor; when given, the produced expert
            ids are remapped through it (e.g. global -> local expert ids).
        pad_sorted_ids: When True, the sorted-ids buffer length is also
            rounded up to a multiple of block_size.

    Returns:
        Tuple of (sorted_token_ids, expert_ids, num_tokens_post_pad):
        flattened token indices grouped by expert with per-expert padding,
        the expert id of each block, and a 1-element int32 tensor holding
        the padded token total.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    # Worst case: every expert needs block_size - 1 padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    # Fix: shape spelled as a proper 1-tuple — `(1)` is just the int 1.
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad