Coverage for src/flag_gems/fused/moe_align_block_size.py: 23%

334 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from functools import lru_cache 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils import has_triton_tle, libentry, libtuner 

10 

# Optional import of the experimental `triton.experimental.tle` modules.
# `HAS_TLE` ends up True only when `has_triton_tle(3, 6, 0)` is truthy AND both
# imports succeed; in every other case `tle`/`tleg` are bound to None so the
# names always exist at module level.
# NOTE(review): `has_triton_tle(3, 6, 0)` is presumably a minimum-version
# check -- confirm against flag_gems.utils.
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg

        HAS_TLE = True
    except ImportError:
        tle = None
        tleg = None
        HAS_TLE = False
else:
    tle = None
    tleg = None
    HAS_TLE = False

25 

26 

logger = logging.getLogger(__name__)

# Extent of the "cluster_x" mesh axis built by `_block_cluster_mesh_8`; also
# passed to the cluster kernel as CLUSTER_SIZE and used to split experts into
# per-rank shards.
TLE_CLUSTER_SIZE = 8
# The cluster-fused kernel is only attempted when the token count is strictly
# below this threshold (see `moe_align_block_size_triton`).
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Set to True by `_install_triton_default_allocator` after it has called
# `triton.set_allocator` once.
_TRITON_ALLOCATOR_INSTALLED = False
# Tuning candidates for the cooperative atomic-fused kernel: only num_warps
# varies (BLOCK_TOKENS is supplied explicitly by the launcher).
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=4),
    triton.Config(kwargs={}, num_warps=8),
]
# Tuning candidates for the cluster-fused kernel: BLOCK_TOKENS x num_warps.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=8),
]

46 

47 

def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
    biased_numerator = a + b - 1
    return biased_numerator // b

50 

51 

def round_up(x: int, y: int) -> int:
    """Return the smallest multiple of ``y`` that is ``>= x`` (for ``y > 0``)."""
    biased = x + y - 1
    # For integers, ``p - p % y`` is exactly ``(p // y) * y``.
    return biased - biased % y

54 

55 

@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Build (and memoize per ``num_blocks``) a TLE device mesh with one
    "block" level whose single axis "block_x" has ``num_blocks`` entries."""
    block_axes = [("block_x", int(num_blocks))]
    return tle.device_mesh({"block": block_axes})

59 

60 

@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Build (once) a TLE device mesh with one "block_cluster" level whose
    single axis "cluster_x" has ``TLE_CLUSTER_SIZE`` entries."""
    cluster_axes = [("cluster_x", TLE_CLUSTER_SIZE)]
    return tle.device_mesh({"block_cluster": cluster_axes})

64 

65 

def _supports_tle_cluster_remote() -> bool:
    """Return True when CUDA is available and the device reported by
    ``torch.cuda.get_device_capability()`` has major capability >= 9.

    NOTE(review): the capability query takes no device argument, so it
    describes torch's current CUDA device rather than any particular tensor's
    device -- confirm that is intended on multi-GPU hosts.
    """
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 9
    return False

71 

72 

def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a torch-backed allocator with Triton, at most once per process.

    The allocator returns a ``uint8`` tensor of ``size`` elements on ``device``
    and ignores the alignment and stream arguments. The module flag
    ``_TRITON_ALLOCATOR_INSTALLED`` makes every later call a no-op.

    NOTE(review): because of the flag, the allocator stays bound to the
    ``device`` of the first call; later calls with another device do not
    rebind it -- confirm this is acceptable for multi-device use.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if not _TRITON_ALLOCATOR_INSTALLED:

        def _torch_byte_buffer(size: int, _alignment: int, _stream: Optional[int]):
            return torch.empty((size,), dtype=torch.uint8, device=device)

        triton.set_allocator(_torch_byte_buffer)
        _TRITON_ALLOCATOR_INSTALLED = True

83 

84 

def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
    """Choose a pair of launch parameters from the problem size.

    Returns a ``(first, second)`` tuple; callers in this file use the first
    value as the token block size (the second is presumably num_warps).
    """
    if num_experts < 256:
        tiny_input = numel <= 512
        few_experts_small_input = num_experts <= 64 and numel <= 2048
        if tiny_input or few_experts_small_input:
            return 128, 8
        return 256, 8

    # Many experts (>= 256): scale the first value with the input size.
    if numel < 1024:
        return 256, 8
    if numel < 32768:
        return 1024, 4
    return 4096, 4

98 

99 

def _pick_tle_atomic_fused_launch_params(
    numel: int, num_experts: int
) -> "tuple[int, int]":
    """Choose launch parameters for the atomic-fused kernel.

    With fewer than 256 experts this defers to
    ``_pick_tle_fused_launch_params``; otherwise the pair is picked from
    ``numel`` alone.
    """
    if num_experts < 256:
        return _pick_tle_fused_launch_params(numel, num_experts)
    if numel > 32768:
        return 1024, 4
    if numel > 16384:
        return 512, 4
    return 256, 8

110 

111 

def _pick_tle_atomic_fused_num_blocks(
    numel: int, num_experts: int, block_tokens: int, device: torch.device
) -> int:
    """Pick the grid size for the cooperative atomic-fused kernel.

    Returns 1 when ``device`` is not a usable CUDA device. Otherwise returns
    the number of ``block_tokens``-sized tiles needed to cover ``numel``,
    capped at a multiple of the device's multiprocessor count (x16 when
    ``num_experts >= 256``, x4 otherwise) and never less than 1.
    """
    cuda_usable = device.type == "cuda" and torch.cuda.is_available()
    if not cuda_usable:
        return 1

    device_props = torch.cuda.get_device_properties(device)
    sm_count = int(getattr(device_props, "multi_processor_count", 1))
    blocks_per_sm = 16 if num_experts >= 256 else 4
    max_blocks = sm_count * blocks_per_sm
    tiles_needed = triton.cdiv(numel, block_tokens)
    return max(1, min(tiles_needed, max_blocks))

123 

124 

@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    """Single-launch MoE block alignment using global atomics.

    Each of the NUM_BLOCKS programs strides over the inputs in steps of
    NUM_BLOCKS * BLOCK_TOKENS. Phases are separated by
    `tle.distributed_barrier(mesh)` (presumably a grid-wide barrier; the host
    launches this kernel with `launch_cooperative_grid=True`):

      1. fill `sorted_token_ids` with `numel`, zero `expert_ids` and `cumsum`;
      2. count tokens per expert locally, then atomically add the local counts
         into `cumsum`, keeping the returned pre-add value per expert;
      3. program 0 converts the totals into block-aligned exclusive start
         offsets and writes the padded total to `num_tokens_post_pad`;
      4. write the owning expert id for every block and scatter token indices
         into `sorted_token_ids`.
    """
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    # BLOCK_EXPERT may exceed num_experts; mask off the padding lanes.
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 1: initialise the outputs. `numel` is the fill value for
    # sorted_token_ids slots that no token is written to later.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Per-program counter buffer allocated with scope `tle.gpu.smem`
    # (presumably shared memory -- confirm against the TLE docs).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Phase 2a: count how many of this program's tokens go to each expert.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Phase 2b: publish the local counts. The value returned by the atomic add
    # is the running total before this program's contribution, i.e. this
    # program's starting rank inside each expert's segment.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    # Reuse local_counts as the per-expert rank cursor for the scatter phase.
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 3: program 0 rounds each expert's total up to a multiple of
    # block_size and replaces cumsum with the exclusive prefix sum (starts).
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Copy the expert start offsets into a per-program buffer.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Phase 4a: each program handles experts pid, pid + NUM_BLOCKS, ... and
    # writes that expert's id into every block of its padded segment.
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        # The last expert's segment ends at the padded total.
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        end_idx = tl.where(has_next, end_idx, total_tokens)
        # Out-of-range experts get an empty [0, 0) range so the loop is skipped.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 4b: scatter. The atomic add on the rank cursor hands out a unique
    # rank within the expert; adding the expert's start gives the final slot.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

249 

250 

@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    """Single-launch MoE block alignment that accumulates through rank 0.

    Work is split by `cluster_rank` (this program's position on the
    "cluster_x" axis of `mesh`), striding in steps of
    CLUSTER_SIZE * BLOCK_TOKENS. Per-expert totals are accumulated into rank
    0's `cumsum_local` buffer through `tle.remote(...)` rather than through a
    global scratch tensor. Phases are separated by
    `tle.distributed_barrier(mesh)`.
    """
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    # BLOCK_EXPERT may exceed num_experts; mask off the padding lanes.
    expert_mask = expert_offsets < num_experts

    # Initialise outputs: `numel` is the fill value for sorted_token_ids
    # slots that no token is written to later; expert_ids is zeroed.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Two per-program buffers allocated with scope `tle.gpu.smem` (presumably
    # shared memory -- confirm against the TLE docs): local token counters and
    # the per-expert totals / start offsets.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    # Rank 0 zeroes its accumulator before any rank adds into it.
    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Count how many of this rank's tokens go to each expert.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Add the local counts into rank 0's buffer. The value returned by the
    # atomic add is the running total before this rank's contribution, i.e.
    # this rank's starting rank inside each expert's segment.
    # NOTE(review): this atomic targets another rank's buffer but uses
    # sem="relaxed", scope="cta" -- confirm TLE gives this the intended
    # cross-rank atomicity.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    # Reuse local_counts as the per-expert rank cursor for the scatter phase.
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Rank 0 rounds each expert's total up to a multiple of block_size and
    # replaces its buffer with the exclusive prefix sum (segment starts).
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every rank copies rank 0's start offsets into its own cumsum_local.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Each rank owns a contiguous shard of EXPERTS_PER_SHARD experts and
    # writes that expert's id into every block of its padded segment.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        # The last expert's segment ends at the padded total.
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        # Out-of-range experts get an empty [0, 0) range so the loop is skipped.
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Scatter: the atomic add on the rank cursor hands out a unique rank within
    # the expert; adding the expert's start offset gives the final slot.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

388 

389 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1: initialise the outputs and count tokens per (program, expert).

    Program ``pid`` fills its slice of ``sorted_token_ids`` with ``numel``,
    zeroes its slice of ``expert_ids``, then increments
    ``tokens_cnts[pid + 1, expert]`` once for each token in its chunk of
    ``topk_ids``.
    """
    pid = tl.program_id(0)

    # Slice of sorted_token_ids owned by this program, filled with `numel`.
    sorted_idx = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    sorted_in_range = sorted_idx < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + sorted_idx, numel, mask=sorted_in_range)

    # Slice of expert_ids owned by this program, zero-filled.
    expert_idx = pid * block_size_expert + tl.arange(0, block_size_expert)
    expert_in_range = expert_idx < numel_expert_ids
    tl.store(expert_ids_ptr + expert_idx, 0, mask=expert_in_range)

    # Chunk of topk_ids handled by this program.
    token_idx = pid * tokens_per_thread + tl.arange(0, tokens_per_thread)
    token_in_range = token_idx < numel
    token_expert = tl.load(topk_ids_ptr + token_idx, mask=token_in_range, other=0)

    # Counters live in row pid + 1; this kernel never writes row 0.
    counter_row_ptr = tokens_cnts_ptr + (pid + 1) * num_experts
    tl.atomic_add(counter_row_ptr + token_expert, 1, mask=token_in_range)

422 

423 

@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (vectorised): in-place inclusive prefix sum down column ``pid``
    of ``tokens_cnts``, over rows ``1 .. num_experts``.

    ``tl.arange(0, num_experts)`` is used directly, so the host only launches
    this variant when ``num_experts`` is a power of two.
    """
    pid = tl.program_id(0)

    rows = tl.arange(0, num_experts) + 1
    column_ptrs = tokens_cnts_ptr + rows * num_experts + pid
    running_counts = tl.cumsum(tl.load(column_ptrs), axis=0)
    tl.store(column_ptrs, running_counts)

435 

436 

@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar loop): in-place inclusive prefix sum down column ``pid``
    of ``tokens_cnts``, over rows ``1 .. num_experts``."""
    pid = tl.program_id(0)

    running_total = 0
    for row in range(1, num_experts + 1):
        cell_ptr = tokens_cnts_ptr + row * num_experts + pid
        running_total += tl.load(cell_ptr)
        tl.store(cell_ptr, running_total)

449 

450 

@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: turn per-expert totals into block-aligned cumulative offsets.

    Reads the last row of ``tokens_cnts`` (row index ``num_experts``), rounds
    each count up to a multiple of ``block_size``, stores the inclusive prefix
    sum into ``cumsum[1 : num_experts + 1]``, and stores the padded grand
    total into ``total_tokens_post_pad``.
    """
    # Offset of row `num_experts` in the row-major (num_experts + 1, num_experts)
    # counter table.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    # `other=0` is required: padding lanes (present whenever num_experts is not
    # a power of two) take part in the tl.sum below, and a masked load without
    # `other` leaves those lanes unspecified.
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask, other=0)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    # cumsum[0] is left as initialised by the host; entries 1..num_experts get
    # the end offset of each expert's padded segment.
    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

472 

473 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: write per-block expert ids and scatter the token indices.

    Program ``pid`` (a) writes ``pid`` into ``expert_ids`` for every block in
    ``[cumsum[pid], cumsum[pid + 1])`` and (b) places each token of its chunk
    at ``cumsum[expert] + rank``, where ``rank`` is the value returned by an
    atomic increment of ``tokens_cnts[pid, expert]``.
    """
    pid = tl.program_id(0)

    # (a) Expert pid owns the padded range [segment_begin, segment_end).
    segment_begin = tl.load(cumsum_ptr + pid)
    segment_end = tl.load(cumsum_ptr + pid + 1)
    for slot in range(segment_begin, segment_end, block_size):
        tl.store(expert_ids_ptr + slot // block_size, pid)

    # (b) Scatter this program's chunk of tokens.
    token_idx = tl.arange(0, tokens_per_thread) + pid * tokens_per_thread
    token_in_range = token_idx < numel
    token_expert = tl.load(topk_ids_ptr + token_idx, mask=token_in_range)

    counter_row_ptr = tokens_cnts_ptr + pid * num_experts
    rank_in_expert = tl.atomic_add(
        counter_row_ptr + token_expert, 1, mask=token_in_range
    )
    expert_base = tl.load(cumsum_ptr + token_expert, mask=token_in_range)
    destination = rank_in_expert + expert_base
    tl.store(sorted_token_ids_ptr + destination, token_idx, mask=token_in_range)

504 

505 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the MoE block-alignment outputs in place.

    Launch strategy, in order of preference:

    1. TLE cluster-fused kernel -- only when TLE is importable, ``topk_ids`` is
       a CUDA tensor, ``next_power_of_2(num_experts) <= 1024``, the token count
       is below ``TLE_BIG_TOKEN_THRESHOLD_TOKENS`` and
       ``_supports_tle_cluster_remote()`` is true.
    2. TLE cooperative atomic-fused kernel (same TLE preconditions).
    3. The plain four-stage Triton pipeline.

    Any exception raised by a TLE launch is logged at debug level and the next
    strategy is tried.

    Args:
        topk_ids: expert id chosen for each (token, top-k slot).
        num_experts: number of experts.
        block_size: alignment granularity of each expert's segment.
        sorted_token_ids: output; token indices grouped by expert. Unused
            slots hold ``topk_ids.numel()``.
        expert_ids: output; owning expert id for each block.
        num_tokens_post_pad: output; padded total token count.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        # Only the block size is taken from the heuristic; num_warps is chosen
        # by the kernel's tuner configs.
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            """Launch the cooperative kernel; return True on success."""
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            # The allocator install is a no-op once its module flag is set, so
            # retrying on "no allocator was set" more than once cannot make
            # progress and would spin forever. Allow exactly one retry.
            allocator_retry_used = False
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg and not allocator_retry_used:
                        allocator_retry_used = True
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative launch rejected: retry with half the grid.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    # Same value as block_expert_tle computed above.
    num_experts_next_power_of_2 = block_expert_tle

    # Stage 1 also re-initialises sorted_token_ids / expert_ids.
    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorised stage 2 uses tl.arange(0, num_experts), so it is only
    # launched when num_experts is a power of two.
    if num_experts == num_experts_next_power_of_2:
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

651 

652 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the alignment outputs and fill them via
    ``moe_align_block_size_triton``.

    Args:
        topk_ids: expert id chosen for each (token, top-k slot).
        block_size: alignment granularity of each expert's segment.
        num_experts: number of experts.
        expert_map: optional lookup tensor; when given, the returned expert
            ids are ``expert_map[expert_ids]``.
        pad_sorted_ids: when true, the ``sorted_ids`` capacity is rounded up
            to a multiple of ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``, all int32 tensors
        on ``topk_ids.device``.
    """
    target_device = topk_ids.device

    # Worst case: every expert needs block_size - 1 padding slots.
    capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)
    num_m_blocks = triton.cdiv(capacity, block_size)

    sorted_ids = torch.empty((capacity,), dtype=torch.int32, device=target_device)
    expert_ids = torch.empty((num_m_blocks,), dtype=torch.int32, device=target_device)
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=target_device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad
    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad