Coverage for src/flag_gems/runtime/backend/_ascend/fused/moe_align_block_size.py: 0%
352 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2from functools import lru_cache
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils import libentry, libtuner
# Module logger created as a child of the "flag_gems" logger.
# NOTE(review): `logger` is assigned a second time further down this module
# via logging.getLogger(__name__), so this binding is replaced during import.
# Confirm which of the two logger names is the intended one.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Report whether the imported Triton is at least ``major.minor.patch``.

    The version string is read from ``triton.__version__`` (``"0.0.0"`` when
    the attribute is missing). Anything after the first ``+`` is discarded,
    only the first three dot-separated fields are considered, and each field
    contributes its leading run of digits (``0`` when there is none).
    Missing fields are treated as ``0``.

    Args:
        major: Required major version.
        minor: Required minor version.
        patch: Required patch version.

    Returns:
        True when the parsed ``(major, minor, patch)`` tuple compares greater
        than or equal to the requested one.
    """
    raw_version = str(getattr(triton, "__version__", "0.0.0"))
    release = raw_version.partition("+")[0]

    def _leading_int(field: str) -> int:
        # Length of the digit prefix; suffixes such as "rc1" are ignored.
        end = 0
        while end < len(field) and field[end].isdigit():
            end += 1
        return int(field[:end]) if end else 0

    numbers = [_leading_int(field) for field in release.split(".")[:3]]
    # Pad short versions such as "3.6" up to three components.
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers) >= (major, minor, patch)
# Optional TLE (triton.experimental.tle) support. The names default to the
# "unavailable" state and are only filled in when Triton is new enough and
# both experimental modules import successfully.
tle = None
tleg = None
HAS_TLE = False
if _triton_version_at_least(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg
    except ImportError:
        # Either import may have failed; reset both so no half-imported
        # state leaks out.
        tle = None
        tleg = None
    else:
        HAS_TLE = True
# Module logger named after this module. This rebinds the `logger` name
# that was already assigned right after the imports.
logger = logging.getLogger(__name__)

# Size of the "cluster_x" mesh axis used by the cluster-fused kernel.
TLE_CLUSTER_SIZE = 8
# The cluster-fused kernel is only attempted below this many tokens.
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Set once _install_triton_default_allocator() has registered an allocator.
_TRITON_ALLOCATOR_INSTALLED = False
# Tuning candidates for the cooperative atomic kernel (num_warps only).
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=warps) for warps in (4, 8)
]
# Tuning candidates for the cluster kernel: every BLOCK_TOKENS value is
# paired with both warp counts, BLOCK_TOKENS varying slowest.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": block_tokens}, num_warps=warps)
    for block_tokens in (128, 256, 512, 1024)
    for warps in (4, 8)
]
def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
    numerator = a + b - 1
    return numerator // b
def round_up(x: int, y: int) -> int:
    """Return the smallest multiple of ``y`` that is not less than ``x``.

    Computed as ``ceil(x / y) * y`` using integer arithmetic.
    """
    multiples = (x + y - 1) // y
    return multiples * y
@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Return a cached TLE device mesh with one "block_x" axis.

    Args:
        num_blocks: Extent of the "block_x" axis; coerced with ``int()``.

    Returns:
        The object produced by ``tle.device_mesh``. Results are memoized per
        ``num_blocks`` (up to 64 distinct values).
    """
    block_axis = ("block_x", int(num_blocks))
    return tle.device_mesh({"block": [block_axis]})
@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Return the cached TLE "block_cluster" mesh.

    The mesh has a single "cluster_x" axis of extent ``TLE_CLUSTER_SIZE``
    and is built only once.
    """
    cluster_axis = ("cluster_x", TLE_CLUSTER_SIZE)
    return tle.device_mesh({"block_cluster": [cluster_axis]})
86def _supports_tle_cluster_remote() -> bool:
87 if not torch.cuda.is_available():
88 return False
89 major, _minor = torch.cuda.get_device_capability()
90 return major >= 9
def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a torch-backed allocator with Triton, at most once.

    The allocator returns an uninitialized ``uint8`` tensor of the requested
    size on ``device`` and ignores the alignment and stream arguments. The
    module flag ``_TRITON_ALLOCATOR_INSTALLED`` makes every later call a
    no-op, so the allocator keeps using the ``device`` of the first call.

    Args:
        device: Device on which the allocator creates its buffers.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if not _TRITON_ALLOCATOR_INSTALLED:

        def _torch_byte_buffer(size: int, _alignment: int, _stream: Optional[int]):
            return torch.empty((size,), dtype=torch.uint8, device=device)

        triton.set_allocator(_torch_byte_buffer)
        _TRITON_ALLOCATOR_INSTALLED = True
105def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
106 if num_experts >= 256:
107 if numel >= 32768:
108 return 4096, 4
109 if numel >= 1024:
110 return 1024, 4
111 return 256, 8
113 if numel <= 512:
114 return 128, 8
115 if num_experts <= 64 and numel <= 2048:
116 return 128, 8
117 return 256, 8
120def _pick_tle_atomic_fused_launch_params(
121 numel: int, num_experts: int
122) -> "tuple[int, int]":
123 if num_experts >= 256:
124 if numel <= 16384:
125 return 256, 8
126 if numel <= 32768:
127 return 512, 4
128 return 1024, 4
129 return _pick_tle_fused_launch_params(numel, num_experts)
132def _pick_tle_atomic_fused_num_blocks(
133 numel: int, num_experts: int, block_tokens: int, device: torch.device
134) -> int:
135 if device.type != "cuda" or not torch.cuda.is_available():
136 return 1
137 props = torch.cuda.get_device_properties(device)
138 sm_count = int(getattr(props, "multi_processor_count", 1))
139 token_programs = triton.cdiv(numel, block_tokens)
140 cap_mult = 4 if num_experts < 256 else 16
141 block_cap = sm_count * cap_mult
142 return max(1, min(token_programs, block_cap))
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    """Fused MoE block alignment using mesh-wide barriers and atomics.

    In one launch the kernel initializes the outputs, counts tokens per
    expert, turns the counts into block_size-aligned start offsets, fills
    ``expert_ids`` and scatters token indices into ``sorted_token_ids``.
    The phases are separated by ``tle.distributed_barrier(mesh)``; the host
    wrapper in this module launches it with ``launch_cooperative_grid=True``.

    Args:
        topk_ids_ptr: Expert id for each of the ``numel`` token slots.
        sorted_token_ids_ptr: Output of ``numel_sorted_token_ids`` entries;
            pre-filled with ``numel`` and then overwritten with token indices.
        expert_ids_ptr: Output of ``numel_expert_ids`` entries; pre-filled
            with 0 and then set to the owning expert of each block.
        num_tokens_post_pad_ptr: Single element receiving the sum of the
            block-aligned per-expert counts.
        cumsum_ptr: Scratch buffer indexed by expert id; zeroed by program 0.
        mesh: Mesh handed to ``tle.distributed_barrier``.
        num_experts: Number of experts.
        block_size: Alignment granularity of each expert's segment.
        numel: Number of entries in ``topk_ids``.
        numel_sorted_token_ids: Length of ``sorted_token_ids``.
        numel_expert_ids: Length of ``expert_ids``.
        NUM_BLOCKS: Stride, in programs, of the token loops; the host passes
            the grid size.
        BLOCK_TOKENS: Tokens processed per loop iteration per program.
        BLOCK_EXPERT: Width of the expert vectors; lanes beyond
            ``num_experts`` are masked off.
        EXPERTS_PER_PROG: Number of experts each program handles when
            filling ``expert_ids``.
    """
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 0: initialize outputs. sorted_token_ids gets `numel` everywhere
    # and expert_ids gets 0; each program covers chunks strided by
    # NUM_BLOCKS * BLOCK_TOKENS.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        # Only program 0 clears the global per-expert accumulator.
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Per-program expert counters, allocated with scope=tle.gpu.smem.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Phase 1: count this program's tokens per expert.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Add the local counts to the global accumulator. The value returned by
    # the atomic (the accumulator content before this program's add) is kept
    # in local_counts as this program's starting rank within each expert.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 2: program 0 replaces the totals in cumsum_ptr with exclusive
    # start offsets of the block_size-aligned segments and publishes the
    # padded token total.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Every program keeps its own copy of the expert start offsets.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Phase 3: fill expert_ids. Program `pid` handles experts
    # pid, pid + NUM_BLOCKS, pid + 2 * NUM_BLOCKS, ...
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        # A segment ends where the next expert starts; the last expert's
        # segment ends at the padded total.
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        end_idx = tl.where(has_next, end_idx, total_tokens)
        # Out-of-range experts get an empty [0, 0) range.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 4: scatter token indices. The destination is the expert's start
    # offset plus this program's running rank; local_counts was seeded with
    # the program's prefix above and is incremented per token.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    """Fused MoE block alignment across the ranks of a "cluster_x" mesh.

    Work is split by ``tle.shard_id(mesh, "cluster_x")``. The ``cumsum_local``
    buffer of rank 0 serves as the shared per-expert accumulator; other ranks
    reach it through ``tle.remote(cumsum_local, 0, scope=mesh)``. The kernel
    initializes the outputs, counts tokens per expert, computes
    block_size-aligned start offsets, fills ``expert_ids`` and scatters token
    indices into ``sorted_token_ids``.

    Args:
        topk_ids_ptr: Expert id for each of the ``numel`` token slots.
        sorted_token_ids_ptr: Output of ``numel_sorted_token_ids`` entries;
            pre-filled with ``numel`` and then overwritten with token indices.
        expert_ids_ptr: Output of ``numel_expert_ids`` entries; pre-filled
            with 0 and then set to the owning expert of each block.
        num_tokens_post_pad_ptr: Single element receiving the sum of the
            block-aligned per-expert counts.
        num_experts: Number of experts.
        block_size: Alignment granularity of each expert's segment.
        numel: Number of entries in ``topk_ids``.
        numel_sorted_token_ids: Length of ``sorted_token_ids``.
        numel_expert_ids: Length of ``expert_ids``.
        mesh: Mesh used for shard ids, remote access and barriers.
        CLUSTER_SIZE: Stride, in ranks, of the token loops.
        BLOCK_TOKENS: Tokens processed per loop iteration per rank (supplied
            by the tuning configs).
        BLOCK_EXPERT: Width of the expert vectors; lanes beyond
            ``num_experts`` are masked off.
        EXPERTS_PER_SHARD: Number of consecutive experts each rank handles
            when filling ``expert_ids``.
    """
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Phase 0: initialize outputs. sorted_token_ids gets `numel` everywhere
    # and expert_ids gets 0; each rank covers chunks strided by
    # CLUSTER_SIZE * BLOCK_TOKENS.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Two per-rank buffers allocated with scope=tle.gpu.smem: per-expert
    # counters and a per-expert offset table.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    # Rank 0 zeroes its cumsum_local before any rank accumulates into it.
    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Phase 1: count this rank's tokens per expert.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Add the local counts to rank 0's accumulator. The value returned by
    # the atomic (the accumulator content before this rank's add) is kept in
    # local_counts as this rank's starting rank within each expert.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Phase 2: rank 0 replaces the totals with exclusive start offsets of
    # the block_size-aligned segments and publishes the padded token total.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every rank copies rank 0's start offsets into its own cumsum_local.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Phase 3: fill expert_ids. Each rank handles the EXPERTS_PER_SHARD
    # consecutive experts starting at cluster_rank * EXPERTS_PER_SHARD.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        # A segment ends where the next expert starts; the last expert's
        # segment ends at the padded total.
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        # Out-of-range experts get an empty [0, 0) range.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    # Synchronize the mesh once more before the scatter pass.
    tle.distributed_barrier(mesh)

    # Phase 4: scatter token indices. The destination is the expert's start
    # offset plus this rank's running rank; local_counts was seeded with the
    # rank's prefix above and is incremented per token.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1: initialize the outputs and count tokens per (program, expert).

    Each program fills one slice of ``sorted_token_ids`` with ``numel`` and
    one slice of ``expert_ids`` with 0. It then counts the experts of the
    tokens in ``[pid * tokens_per_thread, (pid + 1) * tokens_per_thread)``
    into row ``pid + 1`` of ``tokens_cnts`` (row stride ``num_experts``).

    Args:
        topk_ids_ptr: Expert id for each of the ``numel`` token slots.
        tokens_cnts_ptr: Count table addressed as ``row * num_experts + expert``.
        num_experts: Number of experts (row stride of the count table).
        numel: Number of entries in ``topk_ids``.
        tokens_per_thread: Number of token slots handled by each program.
        sorted_token_ids_ptr: Output buffer to fill with ``numel``.
        expert_ids_ptr: Output buffer to fill with 0.
        numel_sorted_token_ids: Length of ``sorted_token_ids``.
        numel_expert_ids: Length of ``expert_ids``.
        block_size_sorted: Slice width per program for ``sorted_token_ids``.
        block_size_expert: Slice width per program for ``expert_ids``.
    """
    pid = tl.program_id(0)

    # Fill this program's slice of sorted_token_ids with `numel`.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Fill this program's slice of expert_ids with 0.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread

    # Counts go to row pid + 1 of the count table.
    off_c = (pid + 1) * num_experts

    # Process tokens one at a time with scalar atomic_add to avoid
    # a triton-ascend bug where vectorized atomic_add to the same
    # address only executes once.
    for i in range(tokens_per_thread):
        offset = start_idx + i
        if offset < numel:
            expert_id = tl.load(topk_ids_ptr + offset)
            tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1)
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (vectorized): inclusive prefix sum down one count column.

    Program ``pid`` loads column ``pid`` of rows ``1..num_experts`` of the
    count table, replaces it with its inclusive cumulative sum and stores it
    back. The host wrapper in this module selects this variant only when
    ``num_experts`` is a power of two.

    Args:
        tokens_cnts_ptr: Count table addressed as ``row * num_experts + expert``.
        num_experts: Number of experts (row stride and number of rows summed).
    """
    pid = tl.program_id(0)

    # Rows 1..num_experts.
    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar loop): inclusive prefix sum down one count column.

    Program ``pid`` walks rows ``1..num_experts`` of column ``pid`` and
    overwrites each entry with the running sum of the entries seen so far.
    The host wrapper in this module uses this variant when ``num_experts``
    is not a power of two.

    Args:
        tokens_cnts_ptr: Count table addressed as ``row * num_experts + expert``.
        num_experts: Number of experts (row stride and number of rows summed).
    """
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: turn per-expert totals into block-aligned cumulative offsets.

    Reads row ``num_experts`` of the count table, rounds every count up to a
    multiple of ``block_size`` and writes the inclusive cumulative sum to
    ``cumsum_ptr[1 : num_experts + 1]``. ``cumsum_ptr[0]`` is not written
    here. The sum of the aligned counts goes to ``total_tokens_post_pad_ptr``.

    Args:
        total_tokens_post_pad_ptr: Single element receiving the padded total.
        tokens_cnts_ptr: Count table addressed as ``row * num_experts + expert``.
        cumsum_ptr: Output buffer of ``num_experts + 1`` offsets.
        num_experts: Number of experts.
        num_experts_next_power_of_2: Vector width used for the expert lanes;
            lanes at or beyond ``num_experts`` are masked.
        block_size: Alignment granularity of each expert's segment.
    """
    # Offset of row `num_experts` in the count table.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: fill ``expert_ids`` and scatter tokens into ``sorted_token_ids``.

    Program ``pid`` plays two roles. As expert ``pid`` it writes ``pid`` into
    every ``expert_ids`` entry whose block lies in
    ``[cumsum[pid], cumsum[pid + 1])``. As token chunk ``pid`` it walks the
    tokens in ``[pid * tokens_per_thread, (pid + 1) * tokens_per_thread)``
    and places each one at its expert's start offset plus a running rank kept
    in row ``pid`` of the count table.

    Args:
        topk_ids_ptr: Expert id for each of the ``numel`` token slots.
        sorted_token_ids_ptr: Output receiving token indices.
        expert_ids_ptr: Output receiving the expert id of each block.
        tokens_cnts_ptr: Count table addressed as ``row * num_experts + expert``.
        cumsum_ptr: Block-aligned offsets with ``num_experts + 1`` entries.
        num_experts: Number of experts (row stride of the count table).
        block_size: Alignment granularity of each expert's segment.
        numel: Number of entries in ``topk_ids``.
        tokens_per_thread: Number of token slots handled by each program.
    """
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # One expert_ids entry per block of expert `pid`'s padded segment.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # From here on `start_idx` is the first token slot of this program.
    start_idx = pid * tokens_per_thread
    # Row `pid` of the count table is used as this program's running rank
    # per expert.
    off_t = pid * num_experts

    # Process tokens one at a time with scalar atomic_add to avoid
    # a triton-ascend bug where vectorized atomic_add to the same
    # address only executes once (drops all but one increment).
    for i in range(tokens_per_thread):
        offset = start_idx + i
        if offset < numel:
            expert_id = tl.load(topk_ids_ptr + offset)
            tl.atomic_add(tokens_cnts_ptr + off_t + expert_id, 1)
            # The counter is read after the increment, so it is 1-based;
            # the `- 1` below converts it to a 0-based slot.
            token_idx_in_expert = tl.load(tokens_cnts_ptr + off_t + expert_id)
            rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id) - 1
            tl.store(sorted_token_ids_ptr + rank_post_pad, offset)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the MoE alignment outputs in place for the given ``topk_ids``.

    Strategy, in order of preference:

    1. When TLE is available, ``topk_ids`` is a CUDA tensor and
       ``next_power_of_2(num_experts) <= 1024``:
       a. the cluster-fused kernel, if there are fewer than
          ``TLE_BIG_TOKEN_THRESHOLD_TOKENS`` tokens and
          ``_supports_tle_cluster_remote()`` is true;
       b. otherwise, or if (a) raises, the cooperative atomic kernel.
    2. If neither TLE kernel ran successfully, the four-stage kernels
       (stage1 .. stage4) with a grid of ``num_experts`` programs.

    Any exception from a TLE launch is logged at debug level and triggers
    the next fallback; it is never propagated.

    Args:
        topk_ids: Expert id per token slot.
        num_experts: Number of experts.
        block_size: Alignment granularity of each expert's segment.
        sorted_token_ids: Output buffer, written in place.
        expert_ids: Output buffer, written in place.
        num_tokens_post_pad: Single-element output buffer, written in place.
    """
    logger.debug("GEMS_ASCEND MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            """Launch the cooperative kernel; return False to request fallback."""
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            # The allocator retry is allowed once per call. The installer is
            # a no-op once its module flag is set, so retrying on every
            # "no allocator was set" error could loop forever.
            allocator_retry_used = False
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg and not allocator_retry_used:
                        allocator_retry_used = True
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative launch rejected: retry with half the grid.
                    # num_blocks strictly decreases, so this terminates.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    # Stage 1: initialize outputs and count tokens per (program, expert).
    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # Stage 2: column-wise prefix sums. The vectorized variant is used only
    # when num_experts is a power of two.
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    # Stage 3: block-aligned cumulative offsets and the padded total.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    # Stage 4: fill expert_ids and scatter token indices.
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the alignment outputs and fill them for ``topk_ids``.

    The capacity of ``sorted_ids`` is
    ``topk_ids.numel() + num_experts * (block_size - 1)``, rounded up to a
    multiple of ``block_size`` when ``pad_sorted_ids`` is set. ``expert_ids``
    holds one entry per ``block_size`` slots of that capacity.

    Args:
        topk_ids: Expert id per token slot.
        block_size: Alignment granularity of each expert's segment.
        num_experts: Number of experts.
        expert_map: Optional lookup tensor; when given, the returned expert
            ids are ``expert_map[expert_ids]``.
        pad_sorted_ids: Round the ``sorted_ids`` capacity up to a multiple of
            ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``: int32 tensors on
        the device of ``topk_ids``, filled by ``moe_align_block_size_triton``.
    """
    device = topk_ids.device

    padded_capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        padded_capacity = round_up(padded_capacity, block_size)
    block_capacity = triton.cdiv(padded_capacity, block_size)

    sorted_ids = torch.empty((padded_capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty((block_capacity,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad
    # Translate the expert ids through the supplied map.
    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad