Coverage for src/flag_gems/fused/moe_align_block_size.py: 26%

348 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2from functools import lru_cache 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils import libentry, libtuner 

10 

11 

def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Return True when the installed Triton is at least (major, minor, patch).

    Tolerant of local-version suffixes ("3.1.0+gitabc" -> "3.1.0") and of
    non-numeric tails inside a component ("0rc1" -> 0); missing components
    default to 0.
    """
    raw = str(getattr(triton, "__version__", "0.0.0")).split("+", 1)[0]
    components = []
    for piece in raw.split(".")[:3]:
        # Keep only the leading run of digits in this component.
        leading = ""
        for ch in piece:
            if not ch.isdigit():
                break
            leading += ch
        components.append(int(leading) if leading else 0)
    # Pad to a full (major, minor, patch) triple before comparing.
    components.extend([0] * (3 - len(components)))
    return tuple(components) >= (major, minor, patch)

27 

28 

# TLE (Triton Language Extensions: device meshes, cross-block remote memory)
# only ships with Triton >= 3.6; gate the import on the version check and
# degrade gracefully so older installs use the plain Triton fallback path.
if _triton_version_at_least(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg

        HAS_TLE = True
    except ImportError:
        # Version string said 3.6+, but the experimental package may still be
        # absent in some builds — treat that the same as "no TLE".
        tle = None
        tleg = None
        HAS_TLE = False
else:
    tle = None
    tleg = None
    HAS_TLE = False

43 

44 

logger = logging.getLogger(__name__)

# Number of programs cooperating on the "cluster_x" mesh axis in the
# single-cluster fused kernel.
TLE_CLUSTER_SIZE = 8
# Token cutoff: at or above this many tokens the cluster-fused kernel is
# skipped in favor of the atomic-fused / staged paths (see dispatcher below).
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Process-wide latch so triton.set_allocator is installed at most once.
_TRITON_ALLOCATOR_INSTALLED = False
# Autotune space for the cooperative atomic-fused kernel: warps only,
# BLOCK_TOKENS is chosen by the host-side heuristic.
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=4),
    triton.Config(kwargs={}, num_warps=8),
]
# Autotune space for the cluster-fused kernel: BLOCK_TOKENS x num_warps grid.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=8),
]

64 

65 

def ceil_div(a, b):
    """Integer ceiling of a / b."""
    numerator = a + b - 1
    return numerator // b

68 

69 

def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y."""
    quotient = (x + y - 1) // y
    return quotient * y

72 

73 

@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Build (and memoize) a 1-D TLE device mesh of `num_blocks` blocks."""
    axis = ("block_x", int(num_blocks))
    return tle.device_mesh({"block": [axis]})

77 

78 

@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Memoized TLE mesh describing one cluster of TLE_CLUSTER_SIZE blocks."""
    spec = {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    return tle.device_mesh(spec)

82 

83 

84def _supports_tle_cluster_remote() -> bool: 

85 if not torch.cuda.is_available(): 

86 return False 

87 major, _minor = torch.cuda.get_device_capability() 

88 return major >= 9 

89 

90 

def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a torch-backed allocator with Triton, at most once per process.

    Used as a retry remedy when a kernel launch fails with "no allocator was
    set" (see the atomic-fused launch loop in the dispatcher).
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if _TRITON_ALLOCATOR_INSTALLED:
        return

    def _allocate(size: int, _alignment: int, _stream: Optional[int]):
        # Alignment and stream hints are ignored; a plain uint8 buffer on the
        # target device is handed back.
        return torch.empty((size,), dtype=torch.uint8, device=device)

    triton.set_allocator(_allocate)
    _TRITON_ALLOCATOR_INSTALLED = True

101 

102 

103def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]": 

104 if num_experts >= 256: 

105 if numel >= 32768: 

106 return 4096, 4 

107 if numel >= 1024: 

108 return 1024, 4 

109 return 256, 8 

110 

111 if numel <= 512: 

112 return 128, 8 

113 if num_experts <= 64 and numel <= 2048: 

114 return 128, 8 

115 return 256, 8 

116 

117 

118def _pick_tle_atomic_fused_launch_params( 

119 numel: int, num_experts: int 

120) -> "tuple[int, int]": 

121 if num_experts >= 256: 

122 if numel <= 16384: 

123 return 256, 8 

124 if numel <= 32768: 

125 return 512, 4 

126 return 1024, 4 

127 return _pick_tle_fused_launch_params(numel, num_experts) 

128 

129 

130def _pick_tle_atomic_fused_num_blocks( 

131 numel: int, num_experts: int, block_tokens: int, device: torch.device 

132) -> int: 

133 if device.type != "cuda" or not torch.cuda.is_available(): 

134 return 1 

135 props = torch.cuda.get_device_properties(device) 

136 sm_count = int(getattr(props, "multi_processor_count", 1)) 

137 token_programs = triton.cdiv(numel, block_tokens) 

138 cap_mult = 4 if num_experts < 256 else 16 

139 block_cap = sm_count * cap_mult 

140 return max(1, min(token_programs, block_cap)) 

141 

142 

# Cooperative single-pass kernel: NUM_BLOCKS programs jointly initialize the
# outputs, histogram tokens per expert, compute block-aligned expert start
# offsets, and scatter each token to its padded slot. Launched with
# launch_cooperative_grid=True so tle.distributed_barrier can synchronize the
# whole grid between phases. cumsum_ptr is a zero-initialized per-expert
# scratch buffer in global memory.
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 1: initialize outputs. sorted_token_ids is filled with `numel`
    # (an out-of-range sentinel), expert_ids with 0; program 0 zeroes the
    # per-expert scratch counters. Grid-stride loops split the work.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 2: per-block histogram of expert ids in shared memory.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        # scope="cta": block-local atomic into shared memory.
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Phase 3: fold the block's histogram into the global counters; the
    # returned pre-add value is this block's rank prefix per expert, which is
    # stashed back into local_counts for the scatter phase.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 4: program 0 converts totals into block_size-aligned exclusive
    # start offsets (written in place over cumsum_ptr) and records the total
    # padded token count.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Phase 5: copy the final start offsets into shared memory for fast reads.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Phase 6: write expert_ids — experts are strided across programs
    # (expert pid, pid + NUM_BLOCKS, ...); each expert stamps its id on every
    # block_size-wide slot in its [start, end) padded range.
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        # The last expert's range ends at the padded total.
        end_idx = tl.where(has_next, end_idx, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 7: scatter. local_counts holds this block's rank prefix (phase 3);
    # re-counting tokens atomically yields each token's rank within the block,
    # so prefix + rank + expert start is its final padded slot.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

267 

268 

# Single-cluster fused kernel: launched with grid (1,) and a CLUSTER_SIZE-wide
# TLE mesh, so the CLUSTER_SIZE programs cooperate through cluster remote
# shared memory (tle.remote) instead of a global scratch buffer. Same four
# logical phases as the cooperative kernel: init, histogram, offsets, scatter.
# Requires SM90+ (see _supports_tle_cluster_remote).
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Init: fill sorted_token_ids with the `numel` sentinel and zero
    # expert_ids, the work strided across cluster ranks.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Per-rank shared-memory buffers: local_counts for the histogram/rank
    # prefix, cumsum_local for the per-expert offsets (rank 0's copy is the
    # cluster-wide accumulator, reached via tle.remote).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Histogram: count this rank's token slice per expert in shared memory.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Fold into rank 0's accumulator via cluster remote atomics; the pre-add
    # value is this rank's prefix per expert, kept in local_counts for scatter.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Rank 0 turns totals into block_size-aligned exclusive start offsets
    # (overwriting its cumsum_local) and records the padded total.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every rank copies the final offsets from rank 0 into its own
    # cumsum_local so later reads are purely local.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Write expert_ids: experts are partitioned contiguously across ranks,
    # EXPERTS_PER_SHARD apiece; the last expert's range ends at total_tokens.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Scatter: re-count atomically on top of the stashed rank prefix; prefix +
    # within-rank rank + expert start offset gives the final padded slot.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

406 

407 

# Stage 1 of the plain-Triton fallback (grid = num_experts programs): each
# program initializes one slice of the outputs and histograms its slice of
# topk_ids into row (pid + 1) of tokens_cnts. Row 0 is left zero so stage 2
# can accumulate with an implicit zero prefix. tokens_per_thread,
# block_size_sorted and block_size_expert are powers of two (tl.arange
# requirement — the dispatcher rounds them up).
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    pid = tl.program_id(0)

    # Fill this program's slice of sorted_token_ids with the `numel` sentinel.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Zero this program's slice of expert_ids.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread

    # Destination row (pid + 1) of the (num_experts + 1, num_experts) matrix.
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)

440 

441 

# Stage 2 (vectorized variant): column-wise running sum over the per-program
# counts, so row i of tokens_cnts becomes the count of tokens for each expert
# among programs 1..i. Requires num_experts to be a power of two (tl.arange);
# the dispatcher falls back to the scalar stage 2 otherwise.
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)  # one program per expert column

    # Rows 1..num_experts of column `pid`, strided by the row width.
    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)

453 

454 

# Stage 2 (scalar variant): same column-wise running sum as stage2_vec, done
# with a sequential loop so it works for any num_experts.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)  # one program per expert column

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)

467 

468 

# Stage 3 (single program): read the final per-expert totals (last row of
# tokens_cnts, filled by stage 2), round each up to a multiple of block_size,
# and write the inclusive cumulative sums into cumsum[1..num_experts]
# (cumsum[0] stays 0, giving exclusive starts for stage 4). Also records the
# total padded token count.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    # Start of the last row of the (num_experts + 1, num_experts) matrix.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    # Inclusive cumsum shifted by one slot -> cumsum[e] is expert e's start.
    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

490 

491 

# Stage 4 (grid = num_experts programs): program `pid` stamps its expert id
# over expert pid's padded [start, end) range in expert_ids, then scatters its
# token slice — row pid of tokens_cnts (zero after stage 1 left row 0 unused
# beyond accumulation; here it serves as the running within-slice rank) plus
# cumsum[expert] gives each token's final slot.
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # One expert-id entry per block_size-wide slot in this expert's range.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    # Row `pid` of tokens_cnts is reused as this program's rank counters.
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    # NOTE(review): ranks come from atomics on row pid, which stage 2 turned
    # into running sums for rows >= 1 — presumably the cross-program prefix is
    # folded in via those accumulated values; verify against stage 2 layout.
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)

522 

523 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the pre-allocated MoE alignment buffers for `topk_ids`.

    Dispatch order: (1) TLE cluster-fused kernel for small token counts on
    hardware that supports cluster remote access, (2) TLE cooperative
    atomic-fused kernel, (3) the 4-stage plain-Triton pipeline. Each TLE
    attempt falls through to the next path on failure.

    All three output tensors (sorted_token_ids, expert_ids,
    num_tokens_post_pad) are written in place; nothing is returned.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    # Per-program tile sizes for the staged fallback; rounded to powers of two
    # because they feed tl.arange.
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    # TLE paths only for CUDA tensors with a modest expert count (the kernels
    # keep a BLOCK_EXPERT-wide vector in shared memory).
    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            # Returns True on a successful launch; False means "use fallback".
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    # Missing allocator: install one and retry the same launch.
                    if "no allocator was set" in msg:
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    # Non-cooperative failures (or no room left to shrink the
                    # grid) are terminal for this path.
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative-launch capacity error: halve the grid and
                    # retry (cooperative grids must fit co-resident on the GPU).
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        # Path 1: cluster-fused kernel for small batches on SM90+.
        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        # Path 2: cooperative atomic-fused kernel.
        if _run_tle_atomic_fused():
            return

    # Path 3: plain-Triton 4-stage pipeline.
    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    # Row 0 is an implicit zero prefix; rows 1..num_experts collect per-program
    # histograms in stage 1 and running sums in stage 2.
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorized stage 2 needs num_experts to be a power of two.
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

669 

670 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Align token->expert assignments to `block_size` granularity.

    Allocates the output buffers, runs the alignment kernels, and optionally
    remaps expert ids through `expert_map`.

    Returns:
        (sorted_ids, expert_ids, num_tokens_post_pad) — padded token order,
        per-block expert ids, and the padded total token count (1-element
        tensor).
    """
    device = topk_ids.device
    # Worst case every expert needs (block_size - 1) padding slots.
    capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)

    sorted_ids = torch.empty((capacity,), dtype=torch.int32, device=device)
    num_m_blocks = triton.cdiv(capacity, block_size)
    expert_ids = torch.empty((num_m_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        # Remap through the caller-provided table — presumably global-to-local
        # expert ids (e.g. expert parallelism); semantics belong to the caller.
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad