Coverage for src/flag_gems/fused/moe_align_block_size.py: 24%

328 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2from functools import lru_cache 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils import libentry, libtuner 

10 

11try: 

12 import triton.experimental.tle.language as tle 

13 

14 HAS_TLE = True 

15except ImportError: 

16 tle = None 

17 HAS_TLE = False 

18 

19 

logger = logging.getLogger(__name__)

# Number of shards on the "cluster_x" axis of the TLE cluster mesh used by the
# cluster-fused kernel below.
TLE_CLUSTER_SIZE = 8
# Token-count cutoff: the single-cluster fused path is only attempted when the
# number of tokens is strictly below this (see moe_align_block_size_triton).
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Process-wide guard so triton.set_allocator() is registered at most once.
_TRITON_ALLOCATOR_INSTALLED = False
# Autotuning space for the atomic-fused cooperative kernel. BLOCK_TOKENS is
# supplied explicitly by the host heuristics, so only num_warps is tuned.
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=4),
    triton.Config(kwargs={}, num_warps=8),
]
# Autotuning space for the cluster-fused kernel: token tile size and warps.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 128}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 256}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 512}, num_warps=8),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=4),
    triton.Config(kwargs={"BLOCK_TOKENS": 1024}, num_warps=8),
]

39 

40 

def ceil_div(a, b):
    """Integer ceiling division: smallest integer q with q * b >= a."""
    quotient, remainder = divmod(a, b)
    return quotient + 1 if remainder else quotient

43 

44 

def round_up(x: int, y: int) -> int:
    """Round *x* up to the nearest multiple of *y*."""
    remainder = x % y
    return x if remainder == 0 else x + (y - remainder)

47 

48 

@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Build (and memoize) a 1-D TLE device mesh spanning *num_blocks* blocks."""
    block_axis = [("block_x", int(num_blocks))]
    return tle.device_mesh({"block": block_axis})

52 

53 

@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Return the cached cluster mesh with TLE_CLUSTER_SIZE shards on cluster_x."""
    cluster_axis = [("cluster_x", TLE_CLUSTER_SIZE)]
    return tle.device_mesh({"block_cluster": cluster_axis})

57 

58 

59def _supports_tle_cluster_remote() -> bool: 

60 if not torch.cuda.is_available(): 

61 return False 

62 major, _minor = torch.cuda.get_device_capability() 

63 return major >= 9 

64 

65 

def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a minimal global Triton allocator backed by ``torch.empty``.

    Idempotent: the module-level flag ensures the allocator is installed at
    most once per process.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if not _TRITON_ALLOCATOR_INSTALLED:

        def _alloc(size: int, _alignment: int, _stream: Optional[int]):
            # Raw byte buffer on the target device; alignment/stream hints unused.
            return torch.empty((size,), dtype=torch.uint8, device=device)

        triton.set_allocator(_alloc)
        _TRITON_ALLOCATOR_INSTALLED = True

76 

77 

78def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]": 

79 if num_experts >= 256: 

80 if numel >= 32768: 

81 return 4096, 4 

82 if numel >= 1024: 

83 return 1024, 4 

84 return 256, 8 

85 

86 if numel <= 512: 

87 return 128, 8 

88 if num_experts <= 64 and numel <= 2048: 

89 return 128, 8 

90 return 256, 8 

91 

92 

93def _pick_tle_atomic_fused_launch_params( 

94 numel: int, num_experts: int 

95) -> "tuple[int, int]": 

96 if num_experts >= 256: 

97 if numel <= 16384: 

98 return 256, 8 

99 if numel <= 32768: 

100 return 512, 4 

101 return 1024, 4 

102 return _pick_tle_fused_launch_params(numel, num_experts) 

103 

104 

105def _pick_tle_atomic_fused_num_blocks( 

106 numel: int, num_experts: int, block_tokens: int, device: torch.device 

107) -> int: 

108 if device.type != "cuda" or not torch.cuda.is_available(): 

109 return 1 

110 props = torch.cuda.get_device_properties(device) 

111 sm_count = int(getattr(props, "multi_processor_count", 1)) 

112 token_programs = triton.cdiv(numel, block_tokens) 

113 cap_mult = 4 if num_experts < 256 else 16 

114 block_cap = sm_count * cap_mult 

115 return max(1, min(token_programs, block_cap)) 

116 

117 

# Single-launch fused moe_align_block_size for a cooperative grid of
# NUM_BLOCKS programs. Programs coordinate through the global cumsum buffer
# and grid-wide barriers (tle.distributed_barrier), so the host must launch
# with launch_cooperative_grid=True.
@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    pid = tl.program_id(0)
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts
    token_offsets = tl.arange(0, BLOCK_TOKENS)

    # Phase 0: grid-strided init. sorted_token_ids is filled with the sentinel
    # `numel` (one past the last valid token index), expert_ids with 0, and
    # program 0 zeroes the shared global cumsum buffer.
    for base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        offs = base + token_offsets
        tl.store(sorted_token_ids_ptr + offs, numel, mask=offs < numel_sorted_token_ids)
    for base in range(pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        tl.store(expert_ids_ptr + offs, 0, mask=offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_offsets, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 1: per-program histogram of expert occurrences in shared memory
    # (tle.gpu.smem scratch), built with CTA-scope atomics.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Phase 2: publish the local histogram into the global per-expert totals.
    # The atomic returns the pre-add value, i.e. this program's offset among
    # all programs' counts for each expert; stash it back into local_counts.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_offsets,
        local_counts_vals,
        mask=expert_mask,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)
    tle.distributed_barrier(mesh)

    # Phase 3: program 0 turns the per-expert totals into block_size-aligned
    # exclusive start offsets (in-place in cumsum) and records the padded total.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_offsets, expert_starts, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Phase 4: cache the expert start offsets in shared memory for fast reuse.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_offsets,))
    expert_starts_vals = tl.load(cumsum_ptr + expert_offsets, mask=expert_mask, other=0)
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=expert_mask)

    # Phase 5: fill expert_ids. Experts are strided across programs
    # (expert pid, pid + NUM_BLOCKS, ...); each program writes the block ids
    # for its experts' [start, end) padded ranges.
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        expert_id = pid + local_expert_idx * NUM_BLOCKS
        valid_expert = expert_id < num_experts
        start_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)),
            mask=valid_expert,
            other=0,
        )
        next_expert = expert_id + 1
        has_next = valid_expert & (next_expert < num_experts)
        end_idx = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (next_expert,)),
            mask=has_next,
            other=total_tokens,
        )
        # Last expert's range ends at the padded total.
        end_idx = tl.where(has_next, end_idx, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_id)

    # Phase 6: scatter tokens. local_counts currently holds this program's
    # inter-program prefix per expert; re-counting tokens with CTA atomics
    # yields (program prefix + intra-program rank), and adding the expert's
    # global start gives the final padded position.
    for base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        offs = base + token_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (expert_id,)), mask=mask, other=0
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

242 

243 

# Single-cluster fused moe_align_block_size: one launch of CLUSTER_SIZE shards
# that coordinate through rank 0's shared-memory cumsum buffer accessed via
# tle.remote, with cluster-wide barriers between phases.
@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    # Phase 0: cluster-strided init of the outputs; sorted_token_ids gets the
    # sentinel `numel`, expert_ids gets 0.
    init_offsets = tl.arange(0, BLOCK_TOKENS)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + offs, numel, mask=mask)
    for base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        offs = base + init_offsets
        mask = offs < numel_expert_ids
        tl.store(expert_ids_ptr + offs, 0, mask=mask)

    # Per-shard shared-memory scratch: a local expert histogram and a cumsum
    # buffer (only rank 0's cumsum_local is the cluster-wide accumulator).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_offsets,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=expert_mask)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_offsets,))
    tl.store(local_counts_ptrs, 0, mask=expert_mask)

    # Phase 1: per-shard histogram of expert occurrences.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        tl.atomic_add(count_ptrs, 1, mask=mask, sem="relaxed", scope="cta")

    # Phase 2: accumulate into rank 0's cumsum via remote shared memory. The
    # returned pre-add value is this shard's inter-shard prefix per expert;
    # keep it in local_counts for the scatter phase.
    local_counts_vals = tl.load(local_counts_ptrs, mask=expert_mask, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        local_counts_vals,
        mask=expert_mask,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=expert_mask)

    tle.distributed_barrier(mesh)

    # Phase 3: rank 0 converts totals into block_size-aligned exclusive start
    # offsets (in place) and publishes the padded token total.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=expert_mask, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=expert_mask)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Phase 4: every shard copies rank 0's start offsets into its own
    # cumsum_local so later reads are local.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_offsets,))
    cumsum_vals = tl.load(rank0_cumsum_remote_ptrs, mask=expert_mask, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_offsets,)),
        cumsum_vals,
        mask=expert_mask,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Phase 5: fill expert_ids; experts are assigned to shards in contiguous
    # chunks of EXPERTS_PER_SHARD.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        expert_idx = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        expert_id = expert_idx
        valid_expert = expert_id < num_experts
        start_ptr = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        start_idx = tl.load(start_ptr, mask=valid_expert, other=0)
        next_expert_id = expert_id + 1
        has_next = valid_expert & (next_expert_id < num_experts)
        next_ptr = tle.gpu.local_ptr(cumsum_local, (next_expert_id,))
        end_from_next = tl.load(next_ptr, mask=has_next, other=0)
        # Last expert's range ends at the padded total.
        end_idx = tl.where(has_next, end_from_next, total_tokens)
        start_idx = tl.where(valid_expert, start_idx, 0)
        end_idx = tl.where(valid_expert, end_idx, 0)
        for i in range(start_idx, end_idx, block_size):
            tl.store(expert_ids_ptr + i // block_size, expert_idx)

    tle.distributed_barrier(mesh)

    # Phase 6: scatter tokens. local_counts holds this shard's inter-shard
    # prefix; re-counting yields (prefix + intra-shard rank), and adding the
    # expert's start offset gives the final padded position.
    for base in range(cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS):
        offs = base + init_offsets
        mask = offs < numel
        expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int32)
        count_ptrs = tle.gpu.local_ptr(local_counts, (expert_id,))
        rank_with_prefix = tl.atomic_add(
            count_ptrs, 1, mask=mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_id,))
        rank_base = tl.load(base_ptrs, mask=mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, offs, mask=mask)

381 

382 

# Stage 1 of the fallback path (grid = num_experts programs): initialize the
# output buffers and build per-program expert histograms.
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    pid = tl.program_id(0)

    # Fill sorted_token_ids with the sentinel `numel` (one past the last
    # valid token index) and expert_ids with 0.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    # Each program counts its own contiguous chunk of tokens into row pid + 1
    # of tokens_cnts (row 0 stays zero for the stage-4 exclusive prefix).
    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)

415 

416 

# Stage 2 (vectorized variant): per-expert column prefix sum over the
# per-program histogram rows. Only used when num_experts is a power of two,
# since tl.arange requires a power-of-two extent (checked by the host).
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # One program per expert column (pid = expert index).
    pid = tl.program_id(0)

    # Rows 1..num_experts hold program-local counts; cumsum them in place so
    # row i becomes the count contributed by programs 0..i-1 (plus i's own).
    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)

428 

429 

# Stage 2 (scalar variant): same column prefix sum as stage2_vec, but with a
# sequential loop so it works for any num_experts.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # One program per expert column (pid = expert index).
    pid = tl.program_id(0)

    # Running sum down rows 1..num_experts of this expert's column.
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)

442 

443 

# Stage 3 (single program): align per-expert totals to block_size and build
# the inclusive cumsum of padded counts at cumsum[1..num_experts]. cumsum[0]
# stays 0 from the host-side torch.zeros initialization.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    # Final row of tokens_cnts (row num_experts) holds the per-expert totals
    # after stage 2.
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    # Round each expert's count up to a multiple of block_size.
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    # Total padded token count across all experts.
    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

465 

466 

# Stage 4 (grid = num_experts programs): each program plays two roles —
# as expert `pid` it fills expert_ids for its padded segment, and as token
# chunk `pid` it scatters its tokens into sorted_token_ids.
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    # Expert pid's padded range is [cumsum[pid], cumsum[pid + 1]); write its
    # expert id for every block_size-sized block inside that range.
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Scatter this program's token chunk. Row pid of tokens_cnts holds, per
    # expert, the number of matching tokens in chunks 0..pid-1 (stage 2's
    # prefix), so the atomic pre-add value is the token's rank within its
    # expert across the whole input.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    # Final position = rank within expert + expert's padded start offset.
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)

497 

498 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Dispatch moe_align_block_size to the best available kernel.

    Tries, in order: the TLE cluster-fused kernel (small token counts on
    SM90+), the TLE atomic-fused cooperative kernel, and finally the generic
    4-stage Triton fallback. All results are written in place into
    ``sorted_token_ids``, ``expert_ids`` and ``num_tokens_post_pad``.

    Fix over the previous revision: the cooperative-launch retry loop could
    spin forever if installing a Triton allocator did not clear the
    "no allocator was set" error — the allocator install is now attempted at
    most once per call.
    """
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    # Power-of-two extents for tl.arange in the fallback kernels.
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    block_expert_tle = triton.next_power_of_2(num_experts)

    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            # Launch the cooperative atomic-fused kernel; returns False when
            # the caller should fall back to the generic Triton path.
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            allocator_retry_done = False
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    # Install a default allocator once; if the same error
                    # persists afterwards, give up instead of looping forever.
                    if "no allocator was set" in msg and not allocator_retry_done:
                        allocator_retry_done = True
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative-launch capacity errors: halve the grid and
                    # retry until a single block is reached.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # Generic 4-stage fallback.
    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorized prefix-sum kernel needs a power-of-two expert count.
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

643 

644 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Group token->expert assignments into block_size-aligned segments.

    Allocates the output buffers, runs the alignment kernels, and optionally
    remaps expert ids through ``expert_map``. Returns
    ``(sorted_ids, expert_ids, num_tokens_post_pad)``.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    device = topk_ids.device

    # Worst case: every expert's segment needs up to (block_size - 1) padding.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)

    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=device
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad

677 return sorted_ids, expert_ids, num_tokens_post_pad