Coverage for src/flag_gems/runtime/backend/_tsingmicro/fused/moe_align_block_size.py: 0%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8try: 

9 import triton.experimental.tle.language.gpu as tle 

10 

11 HAS_TLE = True 

12except ImportError: 

13 HAS_TLE = False 

14 

15 

16logger = logging.getLogger(__name__) 

17 

18 

def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``.

    The result is the mathematical ceiling for non-negative integer ``a``
    and positive integer ``b``.
    """
    biased_numerator = a + b - 1
    return biased_numerator // b

21 

22 

def round_up(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y``.

    For non-negative ``x`` and positive ``y`` this is the smallest multiple
    of ``y`` that is greater than or equal to ``x``.
    """
    multiples_needed = (x + y - 1) // y
    return multiples_needed * y

25 

26 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1_tle(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
):
    """Stage 1 (TLE variant): reset the outputs and count tokens per expert.

    Each program ``pid``:

    * writes ``numel`` into its ``block_size_sorted``-wide slice of
      ``sorted_token_ids_ptr`` and ``0`` into its ``block_size_expert``-wide
      slice of ``expert_ids_ptr`` (both masked to the buffer length);
    * loads ``tokens_per_thread`` consecutive ids from ``topk_ids_ptr``
      starting at ``pid * tokens_per_thread``;
    * builds a per-expert histogram in a buffer obtained from
      ``tle.alloc(scope=tle.smem)`` using atomic adds;
    * stores the first ``num_experts`` histogram entries at
      ``tokens_cnts_ptr + (pid + 1) * num_experts``.

    Ids outside ``[0, num_experts)`` are not counted.
    """
    pid = tl.program_id(0)

    # Pre-fill this program's slice of sorted_token_ids with `numel`.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Pre-fill this program's slice of expert_ids with 0.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread
    # Counts for program `pid` are written one row further down (pid + 1).
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0).to(tl.int32)
    # Reject negative ids as well as ids >= num_experts: either would index
    # outside the histogram buffer in the atomic add below.
    valid = mask & (expert_id >= 0) & (expert_id < num_experts)
    expert_id = tl.where(valid, expert_id, 0)

    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    smem_counts = tle.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.smem,
        nv_mma_shared_layout=False,
    )
    smem_ptrs = tle.local_ptr(smem_counts, (expert_offsets,))
    # Zero every histogram slot, then barrier before the atomic updates.
    tl.store(smem_ptrs, 0)
    tl.debug_barrier()

    count_ptrs = tle.local_ptr(smem_counts, (expert_id,))
    tl.atomic_add(count_ptrs, 1, mask=valid, sem="relaxed", scope="cta")
    # Barrier again before the histogram is read back.
    tl.debug_barrier()

    counts = tl.load(smem_ptrs, mask=expert_mask, other=0)
    tl.store(tokens_cnts_ptr + off_c + expert_offsets, counts, mask=expert_mask)

81 

82 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    """Stage 1 (plain variant): reset the outputs and count tokens per expert.

    Each program ``pid`` writes ``numel`` into its slice of
    ``sorted_token_ids_ptr`` and ``0`` into its slice of ``expert_ids_ptr``,
    then walks ``tokens_per_thread`` consecutive ids of ``topk_ids_ptr``
    starting at ``pid * tokens_per_thread`` and increments the matching
    counter at ``tokens_cnts_ptr + (pid + 1) * num_experts + id``.

    Ids outside ``[0, num_experts)`` are skipped, matching the TLE variant
    of this stage.
    """
    pid = tl.program_id(0)

    # Pre-fill this program's slice of sorted_token_ids with `numel`.
    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    # Pre-fill this program's slice of expert_ids with 0.
    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread

    # Counts for program `pid` are written one row further down (pid + 1).
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            # An id outside [0, num_experts) would update a counter that
            # belongs to another row (or lies outside the buffer); skip it.
            if idx >= 0:
                if idx < num_experts:
                    token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
                    tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)

116 

117 

@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    experts_per_cta: tl.constexpr,
):
    """Stage 2 (vectorised): in-place cumulative sum down the count rows.

    Each program handles ``experts_per_cta`` consecutive columns. It loads
    the ``[num_experts, experts_per_cta]`` tile covering rows
    ``1..num_experts`` of those columns (row stride ``num_experts``),
    replaces it with its cumulative sum along the row axis, and stores it
    back to the same addresses.
    """
    pid = tl.program_id(0)

    # Row indices 1..num_experts, shape [num_experts, 1].
    row_ids = (tl.arange(0, num_experts) + 1)[:, None]
    # Column range owned by this program, shape [1, experts_per_cta].
    col_ids = (tl.arange(0, experts_per_cta) + pid * experts_per_cta)[None, :]

    # Flat element offsets of the tile, shape [num_experts, experts_per_cta].
    tile_offsets = row_ids * num_experts + col_ids

    tile_counts = tl.load(tokens_cnts_ptr + tile_offsets)
    running_counts = tl.cumsum(tile_counts, axis=0)
    tl.store(tokens_cnts_ptr + tile_offsets, running_counts)

136 

137 

@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    """Stage 2 (scalar): in-place running sum down one column of counts.

    Program ``pid`` visits rows ``1..num_experts`` of column ``pid`` (row
    stride ``num_experts``) and overwrites each entry with the sum of that
    entry and all entries above it in the visited range.
    """
    pid = tl.program_id(0)

    running_total = 0
    for row in range(1, num_experts + 1):
        row_count = tl.load(tokens_cnts_ptr + row * num_experts + pid)
        running_total = running_total + row_count
        tl.store(tokens_cnts_ptr + row * num_experts + pid, running_total)

150 

151 

@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    """Stage 3: pad per-expert totals to ``block_size`` and prefix-sum them.

    Reads ``num_experts`` totals starting at element offset
    ``num_experts * num_experts`` of ``tokens_cnts_ptr``, rounds each up to
    a multiple of ``block_size``, writes their inclusive cumulative sum to
    ``cumsum_ptr[1 : num_experts + 1]`` and writes the grand total to
    ``total_tokens_post_pad_ptr``. ``cumsum_ptr[0]`` is not written here.
    """
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    # `other=0` is required: lanes past num_experts take part in the
    # cumsum/sum below, and a masked load without `other` leaves their
    # values unspecified.
    token_cnts = tl.load(
        tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask, other=0
    )
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

173 

174 

@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    """Stage 4: label blocks with their expert and scatter token indices.

    Program ``pid`` does two things:

    * for every ``block_size`` step in ``[cumsum[pid], cumsum[pid + 1])`` it
      writes ``pid`` into ``expert_ids_ptr[i // block_size]``;
    * for each token index ``i`` in its chunk
      ``[pid * tokens_per_thread, min(pid * tokens_per_thread +
      tokens_per_thread, numel))`` it computes
      ``tokens_cnts[pid * num_experts + e] + cumsum[e]`` for the token's
      expert ``e``, stores ``i`` at that position of
      ``sorted_token_ids_ptr`` and increments the counter.

    Tokens whose expert id is outside ``[0, num_experts)`` are skipped.
    """
    pid = tl.program_id(0)
    block_begin = tl.load(cumsum_ptr + pid)
    block_end = tl.load(cumsum_ptr + pid + 1)

    for i in range(block_begin, block_end, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        # An out-of-range id would index outside this program's count row
        # and outside cumsum, producing a bogus destination; skip it.
        if expert_id >= 0:
            if expert_id < num_experts:
                token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
                rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
                tl.store(sorted_token_ids_ptr + rank_post_pad, i)
                tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)

203 

204 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Launch the four-stage Triton pipeline that fills the output tensors.

    Stage 1 pre-fills ``sorted_token_ids`` with ``topk_ids.numel()`` and
    ``expert_ids`` with 0 and counts tokens per expert per program; stage 2
    accumulates those counts down the rows; stage 3 pads the per-expert
    totals to ``block_size`` and writes their prefix sum and the grand
    total; stage 4 writes the expert id of each block and scatters token
    indices into ``sorted_token_ids``.

    Args:
        topk_ids: expert id of every (token, top-k) slot.
        num_experts: number of experts; also the launch grid size.
        block_size: alignment unit for the per-expert token counts.
        sorted_token_ids: output buffer, written in place.
        expert_ids: output buffer, written in place.
        num_tokens_post_pad: output buffer receiving the padded total.
    """
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    device = topk_ids.device

    grid = (num_experts,)
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )

    # Clamp every per-program extent to at least 1: an empty input would
    # otherwise give an extent of 0, which is not a usable tl.arange range.
    tokens_per_thread = triton.next_power_of_2(
        max(1, ceil_div(numel, num_experts))
    )
    block_size_sorted = triton.next_power_of_2(
        max(1, ceil_div(numel_sorted_token_ids, num_experts))
    )
    block_size_expert = triton.next_power_of_2(
        max(1, ceil_div(numel_expert_ids, num_experts))
    )

    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)

    if HAS_TLE:
        # NOTE(review): this is a multiple of 32, not necessarily a power of
        # two (e.g. 96); confirm the TLE backend accepts such an extent.
        block_expert = triton.cdiv(num_experts, 32) * 32
        moe_align_block_size_stage1_tle[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
            BLOCK_EXPERT=block_expert,
        )
    else:
        moe_align_block_size_stage1[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
        )

    if num_experts == num_experts_next_power_of_2:
        # num_experts // 16 is 0 for fewer than 16 experts, which used to
        # raise ZeroDivisionError on the next line; use at least one column
        # per program. The result still divides num_experts evenly because
        # both are powers of two.
        experts_per_cta = max(1, num_experts // 16)
        grid2 = (num_experts // experts_per_cta,)
        moe_align_block_size_stage2_vec[grid2](
            tokens_cnts,
            num_experts,
            experts_per_cta,
        )
    else:
        moe_align_block_size_stage2[grid](
            tokens_cnts,
            num_experts,
        )

    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

295 

296 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Allocate the output tensors and run the block-alignment pipeline.

    Args:
        topk_ids: expert id of every (token, top-k) slot.
        block_size: alignment unit for the per-expert token counts.
        num_experts: number of experts.
        expert_map: optional lookup tensor; when given, the returned expert
            ids are ``expert_map[expert_ids]``.
        pad_sorted_ids: when true, the length of the sorted-id buffer is
            rounded up to a multiple of ``block_size``.

    Returns:
        ``(sorted_ids, expert_ids, num_tokens_post_pad)``, all int32 tensors
        on ``topk_ids.device``.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    device = topk_ids.device

    # Upper bound: every expert may need up to block_size - 1 padding slots.
    padded_capacity = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        padded_capacity = round_up(padded_capacity, block_size)
    block_capacity = triton.cdiv(padded_capacity, block_size)

    sorted_ids = torch.empty((padded_capacity,), dtype=torch.int32, device=device)
    expert_ids = torch.empty((block_capacity,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is None:
        return sorted_ids, expert_ids, num_tokens_post_pad

    return sorted_ids, expert_map[expert_ids], num_tokens_post_pad