Coverage for src/flag_gems/fused/moe_align_block_size.py: 26%
348 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2from functools import lru_cache
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils import libentry, libtuner
def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Return True when the installed Triton version is >= (major, minor, patch).

    Local-version suffixes (e.g. "+cu121") are discarded, and each dotted
    component is read only up to its first non-digit character so that
    pre-release tags such as "0rc1" parse leniently as 0.
    """
    raw = str(getattr(triton, "__version__", "0.0.0")).partition("+")[0]

    def _leading_int(component: str) -> int:
        # Collect the run of leading digits; empty run means 0.
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        return int(digits) if digits else 0

    fields = [_leading_int(c) for c in raw.split(".")[:3]]
    # Pad short versions ("3.6" -> (3, 6, 0)) before comparing.
    fields += [0] * (3 - len(fields))
    return tuple(fields) >= (major, minor, patch)
# Tile-level extensions (TLE) only exist in Triton >= 3.6, and even then a
# given build may ship without them — probe defensively and record the
# outcome in HAS_TLE for the dispatch logic below.
if _triton_version_at_least(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg

        HAS_TLE = True
    except ImportError:
        tle = None
        tleg = None
        HAS_TLE = False
else:
    tle = None
    tleg = None
    HAS_TLE = False
logger = logging.getLogger(__name__)

# Number of CTAs per thread-block cluster used by the cluster-fused kernel.
TLE_CLUSTER_SIZE = 8
# Token-count cutoff: at or above this, the single-cluster TLE kernel is
# skipped in favor of the cooperative atomic-fused kernel.
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Process-wide flag guarding the one-time triton.set_allocator() call.
_TRITON_ALLOCATOR_INSTALLED = False
# Autotuning space for the cooperative atomic-fused kernel (warps only;
# BLOCK_TOKENS is supplied by the host-side heuristic).
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=4),
    triton.Config(kwargs={}, num_warps=8),
]
# Autotuning space for the cluster-fused kernel: token-block size x warps.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=8),
]
def ceil_div(a, b):
    """Ceiling division of a by b (intended for positive block sizes)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y (y > 0)."""
    remainder = x % y
    return x if remainder == 0 else x + (y - remainder)
@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Cached 1-D TLE device mesh of `num_blocks` blocks along axis block_x."""
    spec = {"block": [("block_x", int(num_blocks))]}
    return tle.device_mesh(spec)
@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Cached TLE mesh describing one cluster of TLE_CLUSTER_SIZE blocks."""
    spec = {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    return tle.device_mesh(spec)
84def _supports_tle_cluster_remote() -> bool:
85 if not torch.cuda.is_available():
86 return False
87 major, _minor = torch.cuda.get_device_capability()
88 return major >= 9
def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a minimal torch-backed scratch allocator with Triton, once.

    Subsequent calls are no-ops; the allocator hands out uint8 tensors on
    `device` and ignores the alignment/stream hints.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if _TRITON_ALLOCATOR_INSTALLED:
        return

    def _allocate(size: int, _alignment: int, _stream: Optional[int]):
        # torch.empty already provides sufficient alignment for scratch use.
        return torch.empty((size,), dtype=torch.uint8, device=device)

    triton.set_allocator(_allocate)
    _TRITON_ALLOCATOR_INSTALLED = True
103def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
104 if num_experts >= 256:
105 if numel >= 32768:
106 return 4096, 4
107 if numel >= 1024:
108 return 1024, 4
109 return 256, 8
111 if numel <= 512:
112 return 128, 8
113 if num_experts <= 64 and numel <= 2048:
114 return 128, 8
115 return 256, 8
118def _pick_tle_atomic_fused_launch_params(
119 numel: int, num_experts: int
120) -> "tuple[int, int]":
121 if num_experts >= 256:
122 if numel <= 16384:
123 return 256, 8
124 if numel <= 32768:
125 return 512, 4
126 return 1024, 4
127 return _pick_tle_fused_launch_params(numel, num_experts)
130def _pick_tle_atomic_fused_num_blocks(
131 numel: int, num_experts: int, block_tokens: int, device: torch.device
132) -> int:
133 if device.type != "cuda" or not torch.cuda.is_available():
134 return 1
135 props = torch.cuda.get_device_properties(device)
136 sm_count = int(getattr(props, "multi_processor_count", 1))
137 token_programs = triton.cdiv(numel, block_tokens)
138 cap_mult = 4 if num_experts < 256 else 16
139 block_cap = sm_count * cap_mult
140 return max(1, min(token_programs, block_cap))
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    # Fused moe_align_block_size in one cooperative launch of NUM_BLOCKS
    # programs. Per-expert global counters live in cumsum_ptr; device-wide
    # phases are separated with tle.distributed_barrier over `mesh`.
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 0: initialize outputs. sorted_token_ids is filled with `numel`
    # as the padding sentinel, expert_ids with 0, and program 0 zeroes the
    # global per-expert counters.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 1: per-CTA histogram of expert occurrences, accumulated in a
    # shared-memory tile with CTA-scope atomics.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Fold the local histogram into the global counters. The value returned
    # by the atomic is this program's per-expert prefix (how many tokens
    # other programs claimed first); cache it in shared memory for Phase 5.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 2: program 0 rewrites cumsum_ptr in place, turning per-expert
    # totals into block_size-aligned start offsets (exclusive prefix sum),
    # and records the padded grand total.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Phase 3: every program snapshots the expert start offsets into its
    # own shared memory for cheap repeated lookups.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Phase 4: fill expert_ids. Experts are strided across programs
    # (pid, pid + NUM_BLOCKS, ...); each owner stamps its expert id over
    # every block in its [start, end) slot range.
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        # The last expert's range ends at the padded total.
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        end_idx = tl.where(has_next, end_idx, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 5: scatter token indices. local_counts now holds this program's
    # global prefix per expert, so the CTA-local atomic yields each token's
    # global rank within its expert; adding the expert start gives the
    # final padded position.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    # Fused moe_align_block_size executed by a single cluster of
    # CLUSTER_SIZE CTAs. Rank 0's shared-memory buffer serves as the shared
    # accumulator; other ranks reach it via tle.remote, so no global-memory
    # scratch buffer is needed.
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Phase 0: initialize outputs, work strided across cluster ranks
    # (sorted ids -> `numel` padding sentinel, expert ids -> 0).
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Shared-memory scratch: per-rank expert histogram plus a per-rank copy
    # of the cumulative offsets (rank 0's copy doubles as the accumulator).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 1: per-rank histogram of expert occurrences.
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Fold this rank's histogram into rank 0's accumulator; the previous
    # value returned by the atomic is this rank's per-expert prefix, cached
    # back into local_counts for the scatter phase.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Phase 2: rank 0 converts the per-expert totals into block_size-aligned
    # start offsets (exclusive prefix sum) and stores the padded total.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every rank mirrors rank 0's start offsets into its own shared memory.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Phase 3a: fill expert_ids; each rank owns a contiguous chunk of
    # EXPERTS_PER_SHARD experts and stamps their block ranges.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        # The last expert's range ends at the padded total.
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Phase 3b: scatter token indices. local_counts holds this rank's
    # per-expert prefix, so the CTA-local atomic yields each token's rank
    # within its expert; adding the expert start gives the final slot.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    # Stage 1 of the multi-kernel fallback. Each of num_experts programs
    # initializes its slice of the outputs, then histograms its slice of
    # topk_ids into row (pid + 1) of the (num_experts + 1, num_experts)
    # tokens_cnts table; row 0 stays zero for stage 2's prefix sums.
    pid = tl.program_id(0)

    # Fill sorted_token_ids with the `numel` padding sentinel.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Zero expert_ids.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    # Count expert occurrences in this program's contiguous token chunk.
    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # Vectorized stage 2: replace column `pid` of tokens_cnts (rows
    # 1..num_experts) with its inclusive prefix sum, so row i ends up with
    # the number of tokens expert `pid` received from programs 0..i-1.
    # tl.arange requires num_experts to be a power of two — the host
    # dispatches to the scalar variant otherwise.
    pid = tl.program_id(0)

    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # Scalar stage 2 (used when num_experts is not a power of two): walk
    # column `pid` top-down, replacing each row with the running total so
    # row i holds the sum of the original rows 1..i — i.e. the tokens
    # expert `pid` received from programs 0..i-1.
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    # Stage 3, single program: read each expert's total count (the last row
    # of tokens_cnts, at flat offset num_experts * num_experts), round it up
    # to a multiple of block_size, and write the inclusive prefix sums into
    # cumsum[1:]; cumsum[0] stays 0. Also store the padded grand total.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    # Final stage: program `pid` (one per expert) stamps its expert id over
    # its [cumsum[pid], cumsum[pid+1]) slot range, then scatters its
    # contiguous slice of tokens into sorted_token_ids.
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Row `pid` of tokens_cnts holds, per expert, the number of tokens
    # contributed by earlier programs (stage 2's prefix), so the atomic
    # increment assigns each token its rank within its expert; adding the
    # expert's cumsum start yields the final padded slot.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Group token->expert assignments into block_size-aligned runs, in place.

    Fills the caller-preallocated outputs:
      - sorted_token_ids: token indices grouped by expert; unused padding
        slots keep the sentinel value ``topk_ids.numel()``.
      - expert_ids: the expert owning each block of ``block_size`` slots.
      - num_tokens_post_pad: scalar total slot count after padding.

    Dispatch order: single-cluster TLE kernel (small inputs, cluster-capable
    GPU) -> cooperative atomic-fused TLE kernel -> 4-stage Triton fallback.
    TLE launch failures are logged at debug level and fall through.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    # Per-program tile sizes for the staged fallback (one program/expert).
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    # The TLE kernels keep one value per expert in a shared-memory tile, so
    # they are only used when the padded expert count fits in 1024 lanes.
    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            # Launch the cooperative atomic-fused kernel; True on success.
            # Recovers from two failure modes: a missing Triton allocator
            # (install the default one and retry) and a cooperative grid
            # that is too large for the device (halve num_blocks and retry).
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg:
                        # First TLE launch in this process: install the
                        # default scratch allocator and retry.
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative grid rejected: shrink and retry.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        # Small inputs on cluster-capable GPUs: one-cluster fused kernel.
        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    # tokens_cnts row 0 stays zero; row i+1 receives program i's histogram.
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    # Stage 1: initialize outputs + build per-program expert histograms.
    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # Stage 2: per-expert prefix sums over programs (the vectorized variant
    # requires a power-of-two expert count for tl.arange).
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    # Stage 3: block-aligned per-expert start offsets + padded total.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    # Stage 4: write expert_ids and scatter tokens to their final slots.
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate output buffers and run moe_align_block_size_triton.

    Returns (sorted_ids, expert_ids, num_tokens_post_pad). When
    `expert_map` is given, expert_ids are remapped through it.
    """
    device = topk_ids.device
    # Worst case every expert's run needs (block_size - 1) padding slots.
    capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)

    sorted_ids = torch.empty((capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty(
        (triton.cdiv(capacity, block_size),), dtype=torch.int32, device=device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        # Translate global expert ids into the caller's local numbering.
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad