Coverage for src/flag_gems/runtime/backend/_ascend/fused/moe_align_block_size.py: 0%

352 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from functools import lru_cache 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils import libentry, libtuner 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Return True when the installed Triton is at least ``major.minor.patch``.

    The version string is read from ``triton.__version__`` (``"0.0.0"`` when
    the attribute is missing).  Anything after the first ``+`` is discarded,
    and only the leading run of digits of each of the first three
    dot-separated fields is used; a field with no leading digits counts as 0.
    """
    raw_version = str(getattr(triton, "__version__", "0.0.0"))
    release = raw_version.partition("+")[0]

    numbers = []
    for field in release.split(".")[:3]:
        # Measure the leading digit run of this field.
        digit_count = 0
        while digit_count < len(field) and field[digit_count].isdigit():
            digit_count += 1
        numbers.append(int(field[:digit_count]) if digit_count else 0)

    # Pad short versions such as "3.6" up to three components.
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers) >= (major, minor, patch)

29 

30 

# Optional TLE (triton.experimental.tle) support.  The names default to the
# "unavailable" state and are only replaced when Triton is new enough and both
# imports succeed.
tle = None
tleg = None
HAS_TLE = False

if _triton_version_at_least(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
        import triton.experimental.tle.language.gpu as tleg
    except ImportError:
        # The first import may have succeeded before the second failed, so
        # reset both names.
        tle = None
        tleg = None
    else:
        HAS_TLE = True

45 

46 

# NOTE(review): this rebinds the module-level ``logger`` that was already
# created right after the imports; confirm which logger name is intended.
logger = logging.getLogger(__name__)

# Number of shards in the block-cluster mesh used by the cluster-fused kernel.
TLE_CLUSTER_SIZE = 8
# The cluster-fused kernel is only attempted below this many tokens.
TLE_BIG_TOKEN_THRESHOLD_TOKENS = 4096
# Set once _install_triton_default_allocator has registered its allocator.
_TRITON_ALLOCATOR_INSTALLED = False

# Autotuning candidates for the cooperative atomic kernel (num_warps only).
TLE_ATOMIC_WARPS_CONFIGS = [
    triton.Config(kwargs={}, num_warps=warps) for warps in (4, 8)
]

# Autotuning candidates for the cluster kernel: every BLOCK_TOKENS value
# paired with every num_warps value, BLOCK_TOKENS varying slowest.
TLE_CLUSTER_LAUNCH_CONFIGS = [
    triton.Config(kwargs={"BLOCK_TOKENS": block_tokens}, num_warps=warps)
    for block_tokens in (128, 256, 512, 1024)
    for warps in (4, 8)
]

66 

67 

def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
    numerator = a + b - 1
    return numerator // b

70 

71 

def round_up(x: int, y: int) -> int:
    """Return the smallest multiple of ``y`` that is >= ``x`` (for positive ``y``)."""
    multiples = (x + y - 1) // y
    return multiples * y

74 

75 

@lru_cache(maxsize=64)
def _block_mesh(num_blocks: int):
    """Build (and memoize per ``num_blocks``) a TLE device mesh over blocks.

    The mesh has a single ``"block"`` level with one axis ``"block_x"`` of
    extent ``num_blocks``.
    """
    block_axis = ("block_x", int(num_blocks))
    return tle.device_mesh({"block": [block_axis]})

79 

80 

@lru_cache(maxsize=1)
def _block_cluster_mesh_8():
    """Build (once) the TLE block-cluster mesh of ``TLE_CLUSTER_SIZE`` shards.

    The mesh has a single ``"block_cluster"`` level with one axis
    ``"cluster_x"``.
    """
    cluster_axis = ("cluster_x", TLE_CLUSTER_SIZE)
    return tle.device_mesh({"block_cluster": [cluster_axis]})

84 

85 

def _supports_tle_cluster_remote() -> bool:
    """Return True when CUDA is available and the device's major compute capability is >= 9."""
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 9
    return False

91 

92 

def _install_triton_default_allocator(device: torch.device) -> None:
    """Register a torch-backed allocator with Triton, at most once per process.

    The allocator returns a ``uint8`` tensor of the requested size on
    ``device``; the alignment and stream arguments are ignored.  Later calls
    are no-ops, so the device passed on the first call is the one used.
    """
    global _TRITON_ALLOCATOR_INSTALLED
    if not _TRITON_ALLOCATOR_INSTALLED:

        def _torch_byte_buffer(size: int, _alignment: int, _stream: Optional[int]):
            return torch.empty((size,), dtype=torch.uint8, device=device)

        triton.set_allocator(_torch_byte_buffer)
        _TRITON_ALLOCATOR_INSTALLED = True

103 

104 

def _pick_tle_fused_launch_params(numel: int, num_experts: int) -> "tuple[int, int]":
    """Choose ``(block_tokens, num_warps)`` for a TLE fused launch.

    With 256 or more experts the block size grows with ``numel``; otherwise a
    128-token block is used for small problems and 256 for everything else.
    """
    if num_experts >= 256:
        if numel >= 32768:
            choice = (4096, 4)
        elif numel >= 1024:
            choice = (1024, 4)
        else:
            choice = (256, 8)
        return choice

    small_problem = numel <= 512 or (num_experts <= 64 and numel <= 2048)
    return (128, 8) if small_problem else (256, 8)

118 

119 

def _pick_tle_atomic_fused_launch_params(
    numel: int, num_experts: int
) -> "tuple[int, int]":
    """Choose ``(block_tokens, num_warps)`` for the atomic fused kernel.

    Below 256 experts this defers to ``_pick_tle_fused_launch_params``; at 256
    or more experts it uses its own ``numel`` thresholds.
    """
    if num_experts < 256:
        return _pick_tle_fused_launch_params(numel, num_experts)
    if numel <= 16384:
        return 256, 8
    if numel <= 32768:
        return 512, 4
    return 1024, 4

130 

131 

def _pick_tle_atomic_fused_num_blocks(
    numel: int, num_experts: int, block_tokens: int, device: torch.device
) -> int:
    """Choose the grid size for the cooperative atomic fused kernel.

    Returns 1 when ``device`` is not a usable CUDA device.  Otherwise the
    result is the number of ``block_tokens``-sized token slices, capped at
    ``multi_processor_count * 4`` (``* 16`` with 256 or more experts) and
    never below 1.
    """
    on_cuda = device.type == "cuda" and torch.cuda.is_available()
    if not on_cuda:
        return 1

    device_props = torch.cuda.get_device_properties(device)
    sm_count = int(getattr(device_props, "multi_processor_count", 1))
    blocks_per_sm = 16 if num_experts >= 256 else 4
    wanted_blocks = triton.cdiv(numel, block_tokens)
    return max(1, min(wanted_blocks, sm_count * blocks_per_sm))

143 

144 

@libentry()
@libtuner(
    configs=TLE_ATOMIC_WARPS_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_atomic_fused_coop(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    cumsum_ptr,
    mesh: tl.constexpr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_PROG: tl.constexpr,
):
    """Single-launch MoE alignment in which all programs synchronise on ``mesh``.

    Phases, separated by ``tle.distributed_barrier(mesh)``:
      1. Fill ``sorted_token_ids`` with ``numel`` and ``expert_ids`` with 0;
         program 0 zeroes ``cumsum_ptr``.
      2. Each program counts the experts of its token slices in a per-program
         buffer, then adds those counts to ``cumsum_ptr``.
      3. Program 0 turns the totals into block-aligned start offsets and
         writes the padded total to ``num_tokens_post_pad_ptr``.
      4. Each program labels the blocks of the experts it owns and scatters
         its token indices into ``sorted_token_ids``.
    """
    pid = tl.program_id(0)
    expert_lanes = tl.arange(0, BLOCK_EXPERT)
    lane_is_expert = expert_lanes < num_experts
    token_lanes = tl.arange(0, BLOCK_TOKENS)

    # Phase 1: default-fill both outputs, striding by the whole grid.
    for sorted_base in range(
        pid * BLOCK_TOKENS, numel_sorted_token_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        sorted_offs = sorted_base + token_lanes
        tl.store(
            sorted_token_ids_ptr + sorted_offs,
            numel,
            mask=sorted_offs < numel_sorted_token_ids,
        )
    for eid_base in range(
        pid * BLOCK_TOKENS, numel_expert_ids, NUM_BLOCKS * BLOCK_TOKENS
    ):
        eid_offs = eid_base + token_lanes
        tl.store(expert_ids_ptr + eid_offs, 0, mask=eid_offs < numel_expert_ids)
    if pid == 0:
        tl.store(cumsum_ptr + expert_lanes, 0, mask=lane_is_expert)
    tle.distributed_barrier(mesh)

    # Phase 2: per-program expert histogram (buffer allocated with
    # scope=tle.gpu.smem).
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_lanes,))
    tl.store(local_counts_ptrs, 0, mask=lane_is_expert)

    for count_base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        count_offs = count_base + token_lanes
        count_mask = count_offs < numel
        counted_expert = tl.load(
            topk_ids_ptr + count_offs, mask=count_mask, other=0
        ).to(tl.int32)
        counter_ptrs = tle.gpu.local_ptr(local_counts, (counted_expert,))
        tl.atomic_add(counter_ptrs, 1, mask=count_mask, sem="relaxed", scope="cta")

    # Publish this program's counts.  The value returned by the atomic add is
    # stored back into the local counters so that the scatter phase continues
    # counting from it.
    program_counts = tl.load(local_counts_ptrs, mask=lane_is_expert, other=0)
    prefix_before = tl.atomic_add(
        cumsum_ptr + expert_lanes,
        program_counts,
        mask=lane_is_expert,
        sem="acq_rel",
        scope="gpu",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=lane_is_expert)
    tle.distributed_barrier(mesh)

    # Phase 3: program 0 converts totals into block-aligned exclusive starts.
    if pid == 0:
        total_counts = tl.load(cumsum_ptr + expert_lanes, mask=lane_is_expert, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_starts = tl.cumsum(aligned_counts, axis=0) - aligned_counts
        tl.store(cumsum_ptr + expert_lanes, expert_starts, mask=lane_is_expert)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)
    tle.distributed_barrier(mesh)

    # Phase 4a: keep a per-program copy of the start offsets.
    expert_starts_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    expert_starts_ptrs = tle.gpu.local_ptr(expert_starts_local, (expert_lanes,))
    expert_starts_vals = tl.load(
        cumsum_ptr + expert_lanes, mask=lane_is_expert, other=0
    )
    tl.store(expert_starts_ptrs, expert_starts_vals, mask=lane_is_expert)

    # Phase 4b: label the blocks of experts pid, pid + NUM_BLOCKS, ...
    total_tokens = tl.load(num_tokens_post_pad_ptr)
    for local_expert_idx in range(EXPERTS_PER_PROG):
        owned_expert = pid + local_expert_idx * NUM_BLOCKS
        owns_expert = owned_expert < num_experts
        range_start = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (owned_expert,)),
            mask=owns_expert,
            other=0,
        )
        following_expert = owned_expert + 1
        has_following = owns_expert & (following_expert < num_experts)
        range_end = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (following_expert,)),
            mask=has_following,
            other=total_tokens,
        )
        # The last expert's range ends at the padded total; an out-of-range
        # expert gets the empty range [0, 0).
        range_end = tl.where(has_following, range_end, total_tokens)
        range_start = tl.where(owns_expert, range_start, 0)
        range_end = tl.where(owns_expert, range_end, 0)
        for slot in range(range_start, range_end, block_size):
            tl.store(expert_ids_ptr + slot // block_size, owned_expert)

    # Phase 4c: scatter token indices to their padded positions.
    for scatter_base in range(pid * BLOCK_TOKENS, numel, NUM_BLOCKS * BLOCK_TOKENS):
        scatter_offs = scatter_base + token_lanes
        scatter_mask = scatter_offs < numel
        token_expert = tl.load(
            topk_ids_ptr + scatter_offs, mask=scatter_mask, other=0
        ).to(tl.int32)
        slot_ptrs = tle.gpu.local_ptr(local_counts, (token_expert,))
        rank_with_prefix = tl.atomic_add(
            slot_ptrs, 1, mask=scatter_mask, sem="relaxed", scope="cta"
        )
        rank_base = tl.load(
            tle.gpu.local_ptr(expert_starts_local, (token_expert,)),
            mask=scatter_mask,
            other=0,
        )
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, scatter_offs, mask=scatter_mask)

269 

270 

@libentry()
@libtuner(
    configs=TLE_CLUSTER_LAUNCH_CONFIGS,
    key=["numel"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_tle_cluster_fused(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_pad_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    mesh: tl.constexpr,
    CLUSTER_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
    EXPERTS_PER_SHARD: tl.constexpr,
):
    """MoE alignment across the shards of the ``cluster_x`` mesh axis.

    Work is split by ``tle.shard_id(mesh, "cluster_x")``.  Every shard counts
    the experts of its token slices locally and adds its counts to shard 0's
    ``cumsum_local`` buffer (reached through ``tle.remote``).  Shard 0 turns
    the totals into block-aligned start offsets, which every shard then
    copies back to label ``expert_ids`` and scatter token indices into
    ``sorted_token_ids``.
    """
    cluster_rank = tle.shard_id(mesh, "cluster_x")
    is_rank0 = cluster_rank == 0
    expert_lanes = tl.arange(0, BLOCK_EXPERT)
    lane_is_expert = expert_lanes < num_experts

    # Default-fill both outputs: sorted ids with numel, expert ids with 0.
    token_lanes = tl.arange(0, BLOCK_TOKENS)
    for sorted_base in range(
        cluster_rank * BLOCK_TOKENS, numel_sorted_token_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        sorted_offs = sorted_base + token_lanes
        sorted_mask = sorted_offs < numel_sorted_token_ids
        tl.store(sorted_token_ids_ptr + sorted_offs, numel, mask=sorted_mask)
    for eid_base in range(
        cluster_rank * BLOCK_TOKENS, numel_expert_ids, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        eid_offs = eid_base + token_lanes
        eid_mask = eid_offs < numel_expert_ids
        tl.store(expert_ids_ptr + eid_offs, 0, mask=eid_mask)

    # Per-shard buffers (allocated with scope=tle.gpu.smem): a histogram and a
    # copy of the expert start offsets.
    local_counts = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    cumsum_local = tle.gpu.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    # Shard 0 zeroes its accumulator before any shard adds to it.
    rank0_cumsum_ptrs = tle.gpu.local_ptr(cumsum_local, (expert_lanes,))
    if is_rank0:
        tl.store(rank0_cumsum_ptrs, 0, mask=lane_is_expert)
    tle.distributed_barrier(mesh)

    local_counts_ptrs = tle.gpu.local_ptr(local_counts, (expert_lanes,))
    tl.store(local_counts_ptrs, 0, mask=lane_is_expert)

    # Count the experts of this shard's token slices.
    for count_base in range(
        cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        count_offs = count_base + token_lanes
        count_mask = count_offs < numel
        counted_expert = tl.load(
            topk_ids_ptr + count_offs, mask=count_mask, other=0
        ).to(tl.int32)
        counter_ptrs = tle.gpu.local_ptr(local_counts, (counted_expert,))
        tl.atomic_add(counter_ptrs, 1, mask=count_mask, sem="relaxed", scope="cta")

    # Add this shard's counts to shard 0's accumulator.  The value returned by
    # the atomic add is stored back into the local counters so the scatter
    # phase continues counting from it.
    shard_counts = tl.load(local_counts_ptrs, mask=lane_is_expert, other=0)
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_lanes,))
    prefix_before = tl.atomic_add(
        rank0_cumsum_remote_ptrs,
        shard_counts,
        mask=lane_is_expert,
        sem="relaxed",
        scope="cta",
    )
    tl.store(local_counts_ptrs, prefix_before, mask=lane_is_expert)

    tle.distributed_barrier(mesh)

    # Shard 0 converts totals into block-aligned exclusive start offsets.
    if is_rank0:
        total_counts = tl.load(rank0_cumsum_ptrs, mask=lane_is_expert, other=0)
        aligned_counts = tl.cdiv(total_counts, block_size) * block_size
        expert_cumsum_inclusive = tl.cumsum(aligned_counts, axis=0)
        expert_start_offsets = expert_cumsum_inclusive - aligned_counts
        tl.store(rank0_cumsum_ptrs, expert_start_offsets, mask=lane_is_expert)
        total_tokens = tl.sum(aligned_counts, axis=0)
        tl.store(num_tokens_post_pad_ptr, total_tokens)

    tle.distributed_barrier(mesh)

    # Every shard copies shard 0's start offsets into its own cumsum_local.
    rank0_cumsum_remote = tle.remote(cumsum_local, 0, scope=mesh)
    rank0_cumsum_remote_ptrs = tle.gpu.local_ptr(rank0_cumsum_remote, (expert_lanes,))
    start_offsets = tl.load(rank0_cumsum_remote_ptrs, mask=lane_is_expert, other=0)
    tl.store(
        tle.gpu.local_ptr(cumsum_local, (expert_lanes,)),
        start_offsets,
        mask=lane_is_expert,
    )
    total_tokens = tl.load(num_tokens_post_pad_ptr)

    # Label the blocks of the contiguous expert range owned by this shard.
    for local_expert_idx in range(EXPERTS_PER_SHARD):
        shard_expert = cluster_rank * EXPERTS_PER_SHARD + local_expert_idx
        owns_expert = shard_expert < num_experts
        range_start_ptr = tle.gpu.local_ptr(cumsum_local, (shard_expert,))
        range_start = tl.load(range_start_ptr, mask=owns_expert, other=0)
        following_expert = shard_expert + 1
        has_following = owns_expert & (following_expert < num_experts)
        following_ptr = tle.gpu.local_ptr(cumsum_local, (following_expert,))
        end_from_following = tl.load(following_ptr, mask=has_following, other=0)
        # The last expert's range ends at the padded total; an out-of-range
        # expert gets the empty range [0, 0).
        range_end = tl.where(has_following, end_from_following, total_tokens)
        range_start = tl.where(owns_expert, range_start, 0)
        range_end = tl.where(owns_expert, range_end, 0)
        for slot in range(range_start, range_end, block_size):
            tl.store(expert_ids_ptr + slot // block_size, shard_expert)

    tle.distributed_barrier(mesh)

    # Scatter token indices to their padded positions.
    for scatter_base in range(
        cluster_rank * BLOCK_TOKENS, numel, CLUSTER_SIZE * BLOCK_TOKENS
    ):
        scatter_offs = scatter_base + token_lanes
        scatter_mask = scatter_offs < numel
        token_expert = tl.load(
            topk_ids_ptr + scatter_offs, mask=scatter_mask, other=0
        ).to(tl.int32)
        slot_ptrs = tle.gpu.local_ptr(local_counts, (token_expert,))
        rank_with_prefix = tl.atomic_add(
            slot_ptrs, 1, mask=scatter_mask, sem="relaxed", scope="cta"
        )
        base_ptrs = tle.gpu.local_ptr(cumsum_local, (token_expert,))
        rank_base = tl.load(base_ptrs, mask=scatter_mask, other=0)
        rank_post_pad = rank_with_prefix + rank_base
        tl.store(sorted_token_ids_ptr + rank_post_pad, scatter_offs, mask=scatter_mask)

408 

409 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1: default-fill the outputs and count experts per program.

    Program ``pid`` fills its slice of ``sorted_token_ids`` with ``numel`` and
    its slice of ``expert_ids`` with 0, then counts the experts of tokens
    ``[pid * tokens_per_thread, (pid + 1) * tokens_per_thread)`` into row
    ``pid + 1`` of ``tokens_cnts``.
    """
    pid = tl.program_id(0)

    sorted_offs = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    sorted_in_range = sorted_offs < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + sorted_offs, numel, mask=sorted_in_range)

    eid_offs = pid * block_size_expert + tl.arange(0, block_size_expert)
    eid_in_range = eid_offs < numel_expert_ids
    tl.store(expert_ids_ptr + eid_offs, 0, mask=eid_in_range)

    first_token = pid * tokens_per_thread

    # Counts go into row pid + 1.
    row_offset = (pid + 1) * num_experts

    # Tokens are handled one by one with a scalar atomic_add to sidestep a
    # triton-ascend bug where a vectorized atomic_add to the same address is
    # executed only once.
    for step in range(tokens_per_thread):
        token_pos = first_token + step
        if token_pos < numel:
            token_expert = tl.load(topk_ids_ptr + token_pos)
            tl.atomic_add(tokens_cnts_ptr + row_offset + token_expert, 1)

446 

447 

@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (vectorized): in-place running sum down one ``tokens_cnts`` column.

    Program ``pid`` handles column ``pid`` over rows ``1..num_experts``.  The
    caller in this file only launches it when ``num_experts`` is a power of
    two, since ``num_experts`` is used directly as the ``tl.arange`` extent.
    """
    pid = tl.program_id(0)

    rows = tl.arange(0, num_experts) + 1
    column_ptrs = tokens_cnts_ptr + rows * num_experts + pid
    per_row_counts = tl.load(column_ptrs)
    running_counts = tl.cumsum(per_row_counts, axis=0)
    tl.store(tokens_cnts_ptr + rows * num_experts + pid, running_counts)

459 

460 

@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar): in-place running sum down one ``tokens_cnts`` column.

    Program ``pid`` walks rows ``1..num_experts`` of column ``pid`` and
    replaces each entry with the sum of the entries seen so far.
    """
    pid = tl.program_id(0)

    running_total = 0
    for row in range(1, num_experts + 1):
        row_count = tl.load(tokens_cnts_ptr + row * num_experts + pid)
        running_total = running_total + row_count
        tl.store(tokens_cnts_ptr + row * num_experts + pid, running_total)

473 

474 

@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: turn per-expert totals into block-aligned cumulative offsets.

    Reads the per-expert totals from row ``num_experts`` of ``tokens_cnts``,
    rounds each up to a multiple of ``block_size``, writes the inclusive
    running sum to ``cumsum_ptr[1 : num_experts + 1]`` and the padded grand
    total to ``total_tokens_post_pad_ptr``.
    """
    last_row_offset = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    # other=0 pins the padded lanes to zero.  Without it their contents are
    # unspecified, and they would be included in the tl.sum below whenever
    # num_experts is not a power of two.
    token_cnts = tl.load(
        tokens_cnts_ptr + last_row_offset + expert_offsets, mask=mask, other=0
    )
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

496 

497 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: label expert blocks and scatter token indices.

    Program ``pid`` writes ``pid`` into ``expert_ids`` for every block in
    ``[cumsum[pid], cumsum[pid + 1])``.  It then revisits tokens
    ``[pid * tokens_per_thread, (pid + 1) * tokens_per_thread)``, bumping row
    ``pid`` of ``tokens_cnts`` for each token's expert and storing the token
    index at ``cumsum[expert] + updated_count - 1`` in ``sorted_token_ids``.
    """
    pid = tl.program_id(0)
    range_start = tl.load(cumsum_ptr + pid)
    range_end = tl.load(cumsum_ptr + pid + 1)

    for slot in range(range_start, range_end, block_size):
        tl.store(expert_ids_ptr + slot // block_size, pid)

    first_token = pid * tokens_per_thread
    row_offset = pid * num_experts

    # Tokens are handled one by one with a scalar atomic_add to sidestep a
    # triton-ascend bug where a vectorized atomic_add to the same address is
    # executed only once (all but one increment are dropped).
    for step in range(tokens_per_thread):
        token_pos = first_token + step
        if token_pos < numel:
            token_expert = tl.load(topk_ids_ptr + token_pos)
            tl.atomic_add(tokens_cnts_ptr + row_offset + token_expert, 1)
            count_after = tl.load(tokens_cnts_ptr + row_offset + token_expert)
            rank_post_pad = count_after + tl.load(cumsum_ptr + token_expert) - 1
            tl.store(sorted_token_ids_ptr + rank_post_pad, token_pos)

531 

532 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Fill the MoE alignment outputs in place, choosing the best available path.

    When TLE is importable, ``topk_ids`` is a CUDA tensor and the padded
    expert count is at most 1024, the TLE kernels are tried first: the
    cluster-fused kernel for small token counts on devices that pass
    ``_supports_tle_cluster_remote``, then the cooperative atomic kernel.
    If neither succeeds (or the preconditions do not hold), the four-stage
    Triton pipeline is used.

    Args:
        topk_ids: expert id per (token, top-k slot); read only.
        num_experts: number of experts; also the grid size of the staged path.
        block_size: alignment granularity of each expert's token range.
        sorted_token_ids: output, filled with ``topk_ids.numel()`` as padding
            and then with token indices grouped by expert.
        expert_ids: output, expert id per ``block_size``-sized block.
        num_tokens_post_pad: output, one-element tensor with the padded total.
    """
    logger.debug("GEMS_ASCEND MOE ALIGN BLOCK SIZE")
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    grid = (num_experts,)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))
    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))
    # Padded expert count; shared by the TLE kernels and stage 3.
    block_expert_tle = triton.next_power_of_2(num_experts)

    if HAS_TLE and topk_ids.is_cuda and block_expert_tle <= 1024:
        # num_warps from the picker is unused here: the kernels are autotuned
        # over num_warps through their libtuner configs.
        block_tokens_taf, _ = _pick_tle_atomic_fused_launch_params(numel, num_experts)
        experts_per_shard = ceil_div(num_experts, TLE_CLUSTER_SIZE)
        num_tokens = topk_ids.shape[0] if topk_ids.ndim > 1 else numel

        def _run_tle_atomic_fused() -> bool:
            """Launch the cooperative kernel; return False if it cannot run."""
            cumsum_tle = torch.zeros(
                (num_experts,), dtype=torch.int32, device=topk_ids.device
            )
            num_blocks = _pick_tle_atomic_fused_num_blocks(
                numel,
                num_experts,
                block_tokens_taf,
                topk_ids.device,
            )
            experts_per_prog = ceil_div(num_experts, num_blocks)
            # The allocator is installed at most once per call.  Without this
            # guard a persistent "no allocator was set" error would retry
            # forever, because the installer is a no-op once its flag is set.
            allocator_retried = False
            while True:
                try:
                    moe_align_block_size_tle_atomic_fused_coop[(num_blocks,)](
                        topk_ids,
                        sorted_token_ids,
                        expert_ids,
                        num_tokens_post_pad,
                        cumsum_tle,
                        _block_mesh(num_blocks),
                        num_experts,
                        block_size,
                        numel,
                        numel_sorted_token_ids,
                        numel_expert_ids,
                        NUM_BLOCKS=num_blocks,
                        BLOCK_TOKENS=block_tokens_taf,
                        BLOCK_EXPERT=block_expert_tle,
                        EXPERTS_PER_PROG=experts_per_prog,
                        launch_cooperative_grid=True,
                    )
                    return True
                except Exception as ex:
                    msg = str(ex).lower()
                    if "no allocator was set" in msg and not allocator_retried:
                        allocator_retried = True
                        _install_triton_default_allocator(topk_ids.device)
                        continue
                    if num_blocks <= 1 or "cooperative" not in msg:
                        logger.debug(
                            "TLE atomic fused launch failed, fallback to triton: %s",
                            ex,
                        )
                        return False
                    # Cooperative launch rejected: retry with half the grid.
                    num_blocks = max(1, num_blocks // 2)
                    experts_per_prog = ceil_div(num_experts, num_blocks)

        if (
            num_tokens < TLE_BIG_TOKEN_THRESHOLD_TOKENS
            and _supports_tle_cluster_remote()
        ):
            try:
                # BLOCK_TOKENS is not passed here; it comes from the kernel's
                # libtuner configs.
                moe_align_block_size_tle_cluster_fused[(1,)](
                    topk_ids,
                    sorted_token_ids,
                    expert_ids,
                    num_tokens_post_pad,
                    num_experts,
                    block_size,
                    numel,
                    numel_sorted_token_ids,
                    numel_expert_ids,
                    mesh=_block_cluster_mesh_8(),
                    CLUSTER_SIZE=TLE_CLUSTER_SIZE,
                    BLOCK_EXPERT=block_expert_tle,
                    EXPERTS_PER_SHARD=experts_per_shard,
                )
                return
            except Exception as ex:
                logger.debug(
                    "TLE cluster fused launch failed, fallback to atomic/triton: %s",
                    ex,
                )

        if _run_tle_atomic_fused():
            return

    # The tensor needs to be padded before calculating IDs,
    # to prevent out-of-bounds address access.
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
        sorted_token_ids,
        expert_ids,
        numel_sorted_token_ids,
        numel_expert_ids,
        block_size_sorted,
        block_size_expert,
    )
    # The vectorized stage 2 uses num_experts as a tl.arange extent, so it is
    # only chosen when num_experts is already a power of two.
    if num_experts == block_expert_tle:
        moe_align_block_size_stage2_vec[grid](tokens_cnts, num_experts)
    else:
        moe_align_block_size_stage2[grid](tokens_cnts, num_experts)
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_expert_tle,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

678 

679 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the alignment outputs and run ``moe_align_block_size_triton``.

    Args:
        topk_ids: expert id per (token, top-k slot).
        block_size: alignment granularity of each expert's token range.
        num_experts: number of experts.
        expert_map: optional lookup tensor; when given, the returned expert
            ids are ``expert_map[expert_ids]``.
        pad_sorted_ids: when True, the length of ``sorted_ids`` is rounded up
            to a multiple of ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``, all ``int32`` on
        the device of ``topk_ids``.
    """
    device = topk_ids.device

    # Upper bound on the padded length: up to block_size - 1 padding slots
    # for each expert.
    padded_capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        padded_capacity = round_up(padded_capacity, block_size)
    block_capacity = triton.cdiv(padded_capacity, block_size)

    sorted_ids = torch.empty((padded_capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty((block_capacity,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad
    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad