Coverage for src/flag_gems/fused/moe_align_block_size.py: 34%

131 statements  

coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

import logging
from typing import Optional

import torch
import triton
import triton.language as tl

try:
    import triton.experimental.tle.language.gpu as tle

    HAS_TLE = True
except ImportError:
    HAS_TLE = False


logger = logging.getLogger(__name__)


def ceil_div(a, b):
    return (a + b - 1) // b


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

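
# The four "stage" kernels below implement moe_align_block_size: stage 1
# builds per-chunk histograms of how many tokens were routed to each expert,
# stage 2 turns those histograms into per-expert running sums, stage 3 rounds
# each expert's total up to a multiple of block_size and computes the padded
# offsets, and stage 4 scatters token indices into their final sorted slots.
#
# Stage 1, TLE variant: each program pre-fills its slice of the output
# buffers, then counts its chunk of tokens into a shared-memory histogram
# (tle.alloc plus CTA-scoped tl.atomic_add) before flushing the counts to
# row pid + 1 of the global tokens_cnts matrix. Out-of-range expert ids are
# masked out of the histogram.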
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1_tle(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
    BLOCK_EXPERT: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread
    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0).to(tl.int32)
    valid = mask & (expert_id < num_experts)
    expert_id = tl.where(valid, expert_id, 0)

    expert_offsets = tl.arange(0, BLOCK_EXPERT)
    expert_mask = expert_offsets < num_experts

    smem_counts = tle.alloc(
        [BLOCK_EXPERT],
        dtype=tl.int32,
        layout=None,
        scope=tle.smem,
        nv_mma_shared_layout=False,
    )
    smem_ptrs = tle.local_ptr(smem_counts, (expert_offsets,))
    tl.store(smem_ptrs, 0)
    tl.debug_barrier()

    count_ptrs = tle.local_ptr(smem_counts, (expert_id,))
    tl.atomic_add(count_ptrs, 1, mask=valid, sem="relaxed", scope="cta")
    tl.debug_barrier()

    counts = tl.load(smem_ptrs, mask=expert_mask, other=0)
    tl.store(tokens_cnts_ptr + off_c + expert_offsets, counts, mask=expert_mask)

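
# Stage 1, fallback variant (no TLE available): the same pre-fill, but the
# per-chunk histogram is accumulated directly in global memory with
# tl.atomic_add instead of going through shared memory first.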
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    numel_sorted_token_ids: tl.constexpr,
    numel_expert_ids: tl.constexpr,
    block_size_sorted: tl.constexpr,
    block_size_expert: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets_sorted = pid * block_size_sorted + tl.arange(0, block_size_sorted)
    mask_sorted = offsets_sorted < numel_sorted_token_ids
    tl.store(sorted_token_ids_ptr + offsets_sorted, numel, mask=mask_sorted)

    offsets_expert = pid * block_size_expert + tl.arange(0, block_size_expert)
    mask_expert = offsets_expert < numel_expert_ids
    tl.store(expert_ids_ptr + offsets_expert, 0, mask=mask_expert)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    offsets = start_idx + tl.arange(0, tokens_per_thread)
    mask = offsets < numel
    expert_id = tl.load(topk_ids_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(tokens_cnts_ptr + off_c + expert_id, 1, mask=mask)

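
# Stage 2, vectorized variant: program pid owns column pid of tokens_cnts
# (one expert) and turns rows 1..num_experts into a running sum with a single
# tl.cumsum. Only usable when num_experts is a power of two, because
# tl.arange requires a power-of-two extent; the launcher dispatches
# accordingly.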
@triton.jit
def moe_align_block_size_stage2_vec(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    offset = tl.arange(0, num_experts) + 1
    token_cnt = tl.load(tokens_cnts_ptr + offset * num_experts + pid)
    cnt = tl.cumsum(token_cnt, axis=0)
    tl.store(tokens_cnts_ptr + offset * num_experts + pid, cnt)

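
# Stage 2, scalar fallback for non-power-of-two expert counts: the same
# per-column running sum, computed with a sequential loop over the rows.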
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)

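
# Stage 3 (single program): read the final per-expert totals from the last
# row of tokens_cnts, round each up to a multiple of block_size, and write
# the padded cumulative offsets into cumsum[1:] (cumsum[0] stays 0 from its
# zero initialization). The sum of the aligned counts is the total padded
# token count.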
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    num_experts_next_power_of_2: tl.constexpr,
    block_size: tl.constexpr,
):
    off_cnt = num_experts * num_experts

    expert_offsets = tl.arange(0, num_experts_next_power_of_2)
    mask = expert_offsets < num_experts
    token_cnts = tl.load(tokens_cnts_ptr + off_cnt + expert_offsets, mask=mask)
    aligned_cnts = tl.cdiv(token_cnts, block_size) * block_size

    cumsum_values = tl.cumsum(aligned_cnts, axis=0)
    tl.store(cumsum_ptr + 1 + expert_offsets, cumsum_values, mask=mask)

    total_tokens = tl.sum(aligned_cnts, axis=0)
    tl.store(total_tokens_post_pad_ptr, total_tokens)

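
# Stage 4: each program first labels every block_size-wide block in expert
# pid's [cumsum[pid], cumsum[pid + 1]) range with the expert id, then
# revisits its chunk of topk_ids. After stage 2, row pid of tokens_cnts
# holds, per expert, the number of tokens contributed by earlier chunks;
# atomically incrementing it gives each token its rank within its expert,
# and adding cumsum[expert] yields the final padded position.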
@triton.jit(do_not_specialize=["numel"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    offset = tl.arange(0, tokens_per_thread) + start_idx
    mask = offset < numel
    expert_id = tl.load(topk_ids_ptr + offset, mask=mask)
    token_idx_in_expert = tl.atomic_add(
        tokens_cnts_ptr + off_t + expert_id, 1, mask=mask
    )
    rank_post_pad = token_idx_in_expert + tl.load(cumsum_ptr + expert_id, mask=mask)
    tl.store(sorted_token_ids_ptr + rank_post_pad, offset, mask=mask)

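
# Host-side launcher: one program per expert (the token stream is likewise
# split into num_experts chunks, so the same grid serves stages 1, 2, and 4).
# The vectorized stage-2 kernel needs tl.arange(0, num_experts), which Triton
# only permits for power-of-two extents, hence the dispatch below.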
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    numel_sorted_token_ids = sorted_token_ids.numel()
    numel_expert_ids = expert_ids.numel()
    # The output buffers must be pre-filled (padded) before token ranks are
    # computed, to prevent out-of-bounds address accesses.

    grid = (num_experts,)
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_per_thread = triton.next_power_of_2(ceil_div(numel, num_experts))

    block_size_sorted = triton.next_power_of_2(
        ceil_div(numel_sorted_token_ids, num_experts)
    )
    block_size_expert = triton.next_power_of_2(ceil_div(numel_expert_ids, num_experts))

    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    num_experts_next_power_of_2 = triton.next_power_of_2(num_experts)
    block_expert = triton.cdiv(num_experts, 32) * 32

    if HAS_TLE:
        moe_align_block_size_stage1_tle[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
            BLOCK_EXPERT=block_expert,
        )
    else:
        moe_align_block_size_stage1[grid](
            topk_ids,
            tokens_cnts,
            num_experts,
            numel,
            tokens_per_thread,
            sorted_token_ids,
            expert_ids,
            numel_sorted_token_ids,
            numel_expert_ids,
            block_size_sorted,
            block_size_expert,
        )
    if num_experts == triton.next_power_of_2(num_experts):
        moe_align_block_size_stage2_vec[grid](
            tokens_cnts,
            num_experts,
        )
    else:
        moe_align_block_size_stage2[grid](
            tokens_cnts,
            num_experts,
        )
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        num_experts_next_power_of_2,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )

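
# Public entry point: allocates output buffers sized for the worst case
# (every expert padded by up to block_size - 1 slots), runs the four-stage
# pipeline, and optionally remaps expert ids through expert_map.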
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
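

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API surface. Assumes a
    # CUDA device and a working Triton install; the shapes and expert count
    # below are illustrative only.
    topk_ids = torch.randint(0, 8, (16, 2), dtype=torch.int32, device="cuda")
    sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
        topk_ids, block_size=16, num_experts=8
    )
    # sorted_ids[:num_tokens_post_pad] lists flattened token indices grouped
    # by expert, padded per expert to a multiple of block_size with the
    # sentinel topk_ids.numel(); expert_ids[i] names the expert owning the
    # i-th block.
    print(sorted_ids[: num_tokens_post_pad.item()])
    print(expert_ids)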