Coverage for src/flag_gems/fused/moe_align_block_size.py: 23%
334 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2from functools import lru_cache
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils import has_triton_tle, libentry, libtuner
# Optional TLE language extension. Start from the "unavailable" state and only
# flip HAS_TLE when has_triton_tle(3, 6, 0) passes and both modules import.
tle = None
tleg = None
HAS_TLE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg
    except ImportError:
        # If the second import fails the first alias is already bound; reset
        # both so the aliases are never left half-initialized.
        tle = None
        tleg = None
    else:
        HAS_TLE = True
logger = logging.getLogger(__name__)

# Size of the "cluster_x" mesh axis used by the cluster-fused kernel.
TLE_CLUSTER_SIZE = 8
# The cluster-fused kernel is only attempted below this many tokens.
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Flipped to True by _install_triton_default_allocator after it registers
# its allocator.
_TRITON_ALLOCATOR_INSTALLED = False

# Tuning candidates for the cooperative atomic-fused kernel: warp count only.
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=warps) for warps in (4, 8)
]
# Tuning candidates for the cluster-fused kernel: every BLOCK_TOKENS value
# paired with every warp count (BLOCK_TOKENS varies slowest).
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": block_tokens}, num_warps=warps)
    for block_tokens in (128, 256, 512, 1024)
    for warps in (4, 8)
]
def ceil_div(a, b):
    """Return ``a / b`` rounded up, for non-negative ``a`` and positive ``b``."""
    biased = a + b - 1
    return biased // b
def round_up(x: int, y: int) -> int:
    """Return the smallest multiple of ``y`` that is >= ``x`` (``y`` > 0)."""
    multiples = (x + y - 1) // y
    return multiples * y
@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Return a cached TLE device mesh with one "block_x" axis of ``num_blocks``.

    Results are memoized per ``num_blocks`` (up to 64 distinct values).
    Callers must only use this when the TLE modules were imported, since
    ``tle`` is ``None`` otherwise.
    """
    block_axis = ("block_x", int(num_blocks))
    return tle.device_mesh({"block": [block_axis]})
@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Return the cached TLE block-cluster mesh of TLE_CLUSTER_SIZE shards.

    The mesh has a single "cluster_x" axis. Callers must only use this when
    the TLE modules were imported, since ``tle`` is ``None`` otherwise.
    """
    cluster_axis = ("cluster_x", TLE_CLUSTER_SIZE)
    return tle.device_mesh({"block_cluster": [cluster_axis]})
def _supports_tle_cluster_remote() -> bool:
    """Return True when CUDA is usable and the device capability major is >= 9.

    The capability is queried without a device argument, i.e. for whatever
    device torch currently treats as the default CUDA device.
    """
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 9
    return False
def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a torch-backed byte allocator with Triton, once per module.

    The first call installs an allocator that returns uint8 tensors of the
    requested size on ``device``; later calls return immediately because the
    module-level flag is already set.

    NOTE(review): the allocator closes over the ``device`` of the first call,
    so a later call with a different device does not change where buffers are
    allocated -- confirm this is acceptable for multi-device use.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if not _TRITON_ALLOCATOR_INSTALLED:

        def _torch_byte_buffer(size: int, _alignment: int, _stream: Optional[int]):
            # Alignment and stream are accepted for the callback signature
            # but are not used.
            return torch.empty((size,), dtype=torch.uint8, device=device)

        triton.set_allocator(_torch_byte_buffer)
        _TRITON_ALLOCATOR_INSTALLED = True
def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
    """Pick a heuristic ``(block_tokens, second_value)`` pair for a TLE launch.

    The first element is used by callers as the tokens-per-block tile size.
    The second element is presumably a warp count -- TODO confirm; the caller
    visible in this file discards it.

    Args:
        numel: Total number of (token, expert) assignments.
        num_experts: Number of experts.
    """
    if num_experts >= 256:
        # Many experts: scale the tile with the amount of work.
        if numel >= 32768:
            return 4096, 4
        return (1024, 4) if numel >= 1024 else (256, 8)

    # Fewer experts: small problems get the small tile.
    small_problem = numel <= 512 or (num_experts <= 64 and numel <= 2048)
    return (128, 8) if small_problem else (256, 8)
def _pick_tle_atomic_fused_launch_params(
    numel: int, num_experts: int
) -> "tuple[int, int]":
    """Pick a heuristic ``(block_tokens, second_value)`` for the atomic kernel.

    Below 256 experts this defers to ``_pick_tle_fused_launch_params``; at 256
    experts or more it uses its own thresholds on ``numel``. The second
    element is presumably a warp count -- TODO confirm.
    """
    if num_experts < 256:
        return _pick_tle_fused_launch_params(numel, num_experts)
    if numel <= 16384:
        return 256, 8
    if numel <= 32768:
        return 512, 4
    return 1024, 4
def _pick_tle_atomic_fused_num_blocks(
    numel: int, num_experts: int, block_tokens: int, device: torch.device
) -> int:
    """Choose how many programs to launch for the atomic-fused kernel.

    The result is the number of ``block_tokens``-sized tiles needed to cover
    ``numel``, capped at a multiple of the device's multiprocessor count
    (4x below 256 experts, 16x otherwise) and never less than 1. Non-CUDA
    devices, or a host without usable CUDA, always get 1.
    """
    cuda_usable = device.type == "cuda" and torch.cuda.is_available()
    if not cuda_usable:
        return 1

    device_props = torch.cuda.get_device_properties(device)
    num_sms = int(getattr(device_props, "multi_processor_count", 1))
    tiles_needed = triton.cdiv(numel, block_tokens)
    blocks_per_sm = 16 if num_experts >= 256 else 4
    return max(1, min(tiles_needed, num_sms * blocks_per_sm))
# Tuned over TLE_ATOMIC_WARPS_CONFIGS with "numel" as the tuning key; the host
# launches this kernel with launch_cooperative_grid=True.
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    """Single-launch MoE block alignment across NUM_BLOCKS programs.

    Phases are separated by ``tle.distributed_barrier(mesh)`` calls
    (presumably a barrier across every program in ``mesh`` -- TODO confirm):
      1. fill ``sorted_token_ids`` with ``numel`` and ``expert_ids`` with 0;
      2. count tokens per expert locally, then accumulate into ``cumsum_ptr``;
      3. program 0 converts totals into block-aligned expert start offsets;
      4. programs write ``expert_ids`` for the experts they own;
      5. programs scatter token indices into ``sorted_token_ids``.

    Args:
        topk_ids_ptr: Expert id per token slot; ``numel`` entries are read.
        sorted_token_ids_ptr: Output of ``numel_sorted_token_ids`` entries.
        expert_ids_ptr: Output of ``numel_expert_ids`` entries, one per block.
        num_tokens_post_pad_ptr: Output scalar, total padded token count.
        cumsum_ptr: Scratch of at least ``num_experts`` int entries.
        mesh: Mesh object handed to the TLE barrier.
        num_experts: Number of experts.
        block_size: Alignment granularity for each expert's segment.
        numel: Number of entries in ``topk_ids``.
        NUM_BLOCKS: Number of launched programs (the loop stride factor).
        BLOCK_TOKENS: Tokens processed per program per loop iteration.
        BLOCK_EXPERT: Width of the expert vector; masked to ``num_experts``.
        EXPERTS_PER_PROG: Number of experts each program handles in phase 4.
    """
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 1: initialise the outputs. Program `pid` starts at its own tile
    # and advances by NUM_BLOCKS tiles per iteration.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        # `numel` is the fill value for slots that receive no token.
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        # Reset the shared per-expert accumulator before anyone adds to it.
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 2: per-program histogram, allocated with the smem scope.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    # The atomic returns the accumulator value before this program's add,
    # i.e. how many tokens of each expert were claimed by programs whose add
    # landed earlier.
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    # Reuse the local buffer as this program's starting rank per expert.
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 3: program 0 rounds every expert total up to a multiple of
    # block_size and replaces the totals with exclusive-prefix start offsets.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Phase 4: every program takes a local copy of the expert start offsets.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    total_tokens = tl.load(num_tokens_post_pad_ptr)
    # Program `pid` owns experts pid, pid + NUM_BLOCKS, pid + 2 * NUM_BLOCKS, ...
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        # An expert's segment ends where the next one starts; the last
        # expert's segment ends at the padded total.
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        end_idx = tl.where(has_next, end_idx, total_tokens)
        # Out-of-range experts get an empty [0, 0) range so the loop is skipped.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 5: scatter. The local counter (seeded with prefix_before) hands
    # out a distinct rank per token; adding the expert's start offset gives
    # the destination slot.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
# Tuned over TLE_CLUSTER_LAUNCH_CONFIGS (which supply BLOCK_TOKENS) with
# "numel" as the tuning key.
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    """MoE block alignment using the rank-0 shard's buffer as accumulator.

    Each shard is identified by ``tle.shard_id(mesh, "cluster_x")``. Shards
    count tokens per expert locally, add their counts into the buffer of
    shard 0 (reached through ``tle.remote``), and after shard 0 converts the
    totals into block-aligned start offsets every shard copies them back and
    scatters its tokens. Phases are separated by
    ``tle.distributed_barrier(mesh)`` (presumably a barrier across all shards
    in ``mesh`` -- TODO confirm).

    Args:
        topk_ids_ptr: Expert id per token slot; ``numel`` entries are read.
        sorted_token_ids_ptr: Output of ``numel_sorted_token_ids`` entries.
        expert_ids_ptr: Output of ``numel_expert_ids`` entries, one per block.
        num_tokens_post_pad_ptr: Output scalar, total padded token count.
        num_experts: Number of experts.
        block_size: Alignment granularity for each expert's segment.
        numel: Number of entries in ``topk_ids``.
        mesh: Mesh object with a "cluster_x" axis.
        CLUSTER_SIZE: Number of shards (the loop stride factor).
        BLOCK_TOKENS: Tokens processed per shard per loop iteration.
        BLOCK_EXPERT: Width of the expert vector; masked to ``num_experts``.
        EXPERTS_PER_SHARD: Consecutive experts handled by each shard when
            writing ``expert_ids``.
    """
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Initialise the outputs; shard `cluster_rank` starts at its own tile and
    # advances by CLUSTER_SIZE tiles per iteration.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        # `numel` is the fill value for slots that receive no token.
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Two per-shard buffers allocated with the smem scope: a token counter
    # and the expert start offsets (on shard 0 the latter also serves as the
    # accumulator the other shards add into).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        # Shard 0 clears the accumulator before any shard adds to it.
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Count this shard's tokens per expert.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    # Address shard 0's copy of cumsum_local and add this shard's counts.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    # NOTE(review): this atomic targets another shard's buffer but uses
    # scope="cta" with relaxed ordering -- confirm that is sufficient here.
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    # The value before the add becomes this shard's starting rank per expert.
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Shard 0 rounds each expert total up to a multiple of block_size and
    # replaces the totals with exclusive-prefix start offsets.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every shard copies the start offsets from shard 0 into its own buffer.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Shard `cluster_rank` owns the consecutive experts
    # [cluster_rank * EXPERTS_PER_SHARD, (cluster_rank + 1) * EXPERTS_PER_SHARD).
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        # The last expert's segment ends at the padded total.
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        # Out-of-range experts get an empty [0, 0) range so the loop is skipped.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Scatter: the local counter (seeded with prefix_before) hands out a
    # distinct rank per token; adding the expert's start offset gives the
    # destination slot.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1: initialise the outputs and count tokens per expert.

    Program ``pid`` fills its slice of ``sorted_token_ids`` with ``numel`` and
    of ``expert_ids`` with 0, then counts the experts of its
    ``tokens_per_thread`` tokens into row ``pid + 1`` of ``tokens_cnts``
    (rows are ``num_experts`` entries wide; row 0 is not written here).
    """
    pid = tl.program_id(0)

    # Output initialisation, one slice per program.
    sorted_idx = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    tl.store(
        sorted_token_ids_ptr + sorted_idx,
        numel,
        mask=sorted_idx < numel_sorted_token_ids,
    )

    block_idx = pid * block_size_expert + tl.arange(0, block_size_expert)
    tl.store(expert_ids_ptr + block_idx, 0, mask=block_idx < numel_expert_ids)

    # Histogram of this program's tokens into its own counter row.
    token_idx = pid * tokens_per_thread + tl.arange(0, tokens_per_thread)
    in_range = token_idx < numel
    expert_id = tl.load(topk_ids_ptr + token_idx, mask=in_range, other=0)
    counts_row_ptr = tokens_cnts_ptr + (pid + 1) * num_experts
    tl.atomic_add(counts_row_ptr + expert_id, 1, mask=in_range)
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (vectorised): running sum down one column of ``tokens_cnts``.

    Program ``pid`` replaces rows 1..num_experts of column ``pid`` with their
    inclusive cumulative sum in a single vector operation. The host only
    selects this variant when ``num_experts`` is a power of two.
    """
    pid = tl.program_id(0)

    rows = tl.arange(0, num_experts) + 1
    column_ptrs = tokens_cnts_ptr + rows * num_experts + pid
    running_totals = tl.cumsum(tl.load(column_ptrs), axis=0)
    tl.store(column_ptrs, running_totals)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar loop): running sum down one column of ``tokens_cnts``.

    Program ``pid`` walks rows 1..num_experts of column ``pid`` and overwrites
    each entry with the inclusive cumulative sum so far.
    """
    pid = tl.program_id(0)

    running_total = 0
    for row in range(1, num_experts + 1):
        cell_ptr = tokens_cnts_ptr + row * num_experts + pid
        running_total = running_total + tl.load(cell_ptr)
        tl.store(cell_ptr, running_total)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: turn per-expert totals into block-aligned cumulative offsets.

    Reads the last row of ``tokens_cnts`` (row ``num_experts``), rounds each
    expert's count up to a multiple of ``block_size``, writes the inclusive
    cumulative sum to ``cumsum_ptr[1 : num_experts + 1]`` and the grand total
    to ``total_tokens_post_pad_ptr``. ``cumsum_ptr[0]`` is not written here.

    Args:
        total_tokens_post_pad_ptr: Output scalar, total padded token count.
        tokens_cnts_ptr: Counter table, ``num_experts`` entries per row.
        cumsum_ptr: Output of at least ``num_experts + 1`` entries.
        num_experts: Number of experts.
        num_experts_next_power_of_2: Vector width, >= ``num_experts``.
        block_size: Alignment granularity for each expert's segment.
    """
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    # other=0 is required: lanes beyond num_experts take part in the tl.sum
    # below, so they must contribute exactly zero rather than whatever a
    # masked load leaves in them.
    token_cnts = tl.load(
        tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask, other=0
    )
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: label blocks with their expert and scatter token indices.

    Program ``pid`` first writes ``pid`` into ``expert_ids`` for every block in
    ``[cumsum[pid], cumsum[pid + 1])``. It then places each of its
    ``tokens_per_thread`` tokens at ``cumsum[expert] + rank``, where ``rank``
    is the value returned by atomically incrementing row ``pid`` of
    ``tokens_cnts`` for that expert.
    """
    pid = tl.program_id(0)

    # Part 1: this program acts as expert `pid` and labels its blocks.
    segment_begin = tl.load(cumsum_ptr + pid)
    segment_end = tl.load(cumsum_ptr + pid + 1)
    for pos in range(segment_begin, segment_end, block_size):
        tl.store(expert_ids_ptr + pos // block_size, pid)

    # Part 2: this program acts as token chunk `pid` and scatters its tokens.
    token_idx = tl.arange(0, tokens_per_thread) + pid * tokens_per_thread
    in_range = token_idx < numel
    expert_id = tl.load(topk_ids_ptr + token_idx, mask=in_range)
    counts_row_ptr = tokens_cnts_ptr + pid * num_experts
    rank_in_expert = tl.atomic_add(counts_row_ptr + expert_id, 1, mask=in_range)
    expert_base = tl.load(cumsum_ptr + expert_id, mask=in_range)
    destination = rank_in_expert + expert_base
    tl.store(sorted_token_ids_ptr + destination, token_idx, mask=in_range)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the MoE alignment outputs in place, using the best available path.

    Order of attempts:
      1. (TLE available, CUDA input, padded expert count <= 1024, fewer than
         TLE_BIG_TOKEN_THRESHOLD_TOKENS tokens, capable device) the
         cluster-fused kernel;
      2. (TLE available, CUDA input, padded expert count <= 1024) the
         cooperative atomic-fused kernel, halving the grid when a
         cooperative-launch error is reported;
      3. otherwise, or when the above fail, the four-stage Triton pipeline.

    Args:
        topk_ids: Expert id chosen for each token slot.
        num_experts: Number of experts.
        block_size: Alignment granularity for each expert's segment.
        sorted_token_ids: Output buffer, written in place.
        expert_ids: Output buffer (one entry per block), written in place.
        num_tokens_post_pad: Output scalar tensor, written in place.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    # Computed once; used as BLOCK_EXPERT for the TLE kernels, as the vector
    # width of stage 3, and to choose the stage 2 variant.
    num_experts_pow2 = triton.next_power_of_2(num_experts)

    if HAS_TLE and topk_ids.is_cuda and num_experts_pow2 <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            """Launch the cooperative kernel; return False if it cannot run."""
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            # The allocator install is a no-op once its module flag is set, so
            # retrying on the same error more than once can never make
            # progress. Allow exactly one retry to keep this loop finite.
            allocator_retry_used = False
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=num_experts_pow2,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg and not allocator_retry_used:
                        allocator_retry_used = True
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative launch rejected: retry with half the grid.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=num_experts_pow2,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                # Best effort: any failure falls through to the next path.
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorised stage 2 needs a power-of-two expert count.
    if num_experts == num_experts_pow2:
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_pow2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the alignment outputs and fill them for ``topk_ids``.

    The sorted-id buffer is sized for the worst case, where every expert
    needs ``block_size - 1`` padding slots, and is optionally rounded up to a
    multiple of ``block_size``.

    Args:
        topk_ids: Expert id chosen for each token slot.
        block_size: Alignment granularity for each expert's segment.
        num_experts: Number of experts.
        expert_map: Optional lookup tensor; when given, the returned expert
            ids are ``expert_map[expert_ids]``.
        pad_sorted_ids: Round the sorted-id capacity up to a multiple of
            ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``, all int32 tensors
        allocated on ``topk_ids.device``.
    """
    device = topk_ids.device

    capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)
    num_blocks = triton.cdiv(capacity, block_size)

    # Uninitialised buffers; moe_align_block_size_triton writes them in place.
    sorted_ids = torch.empty((capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty((num_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad
    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad