Coverage for src/flag_gems/fused/fused_moe.py: 10%

163 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.fused.moe_align_block_size import moe_align_block_size 

9from flag_gems.fused.moe_sum import moe_sum 

10from flag_gems.fused.silu_and_mul import silu_and_mul_kernel 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Zero-fill this program's output tile of C.

    Used when an M-block maps to no real expert (expert id == -1) so the
    corresponding rows of C still hold well-defined values.
    """
    zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    col_offs = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Rows are addressed indirectly through the sorted token ids.
    out_ptrs = c_ptr + offs_token[:, None] * stride_cm + col_offs[None, :] * stride_cn
    # Mask out padding rows and the ragged tail along N.
    out_mask = token_mask[:, None] & (col_offs[None, :] < N)
    tl.store(out_ptrs, zero_tile, mask=out_mask)

33 

34 

@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
):
    """
    Fused MoE GEMM kernel with expert-based indirect addressing.

    Computes: C[t, :] = A[t // topk, :] @ B[expert(t), :, :] [* topk_weight[t]]

    Key Parameters:
    - A: Input activations [M, K] (or quantized)
    - B: Stacked expert weights [E, N, K]
    - C: Output [num_sorted_tokens, N] (indexed by sorted_token_ids)
    - sorted_token_ids: Per-expert sorted token indices (from moe_align_block_size)
    - expert_ids: Expert index for each M-block
    - group_n/group_k: scale-block sizes along N/K; both > 0 selects the
      block-wise quantization path, 0 disables it
    """
    # Map program id to the block of C it should compute.
    # Grouped ordering promotes L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may be ragged; clamp its M extent.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Load sorted token indices for this M-block.
    # int64 offsets guard against overflow in the pointer arithmetic below.
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + offs
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    offs_token = offs_token.to(tl.int64)
    # Entries >= num_valid_tokens are alignment padding, not real tokens.
    token_mask = offs_token < num_valid_tokens

    # Determine which expert this block belongs to
    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -1 marks a block with no assigned expert (e.g. expert not on this
        # rank) — emit zeros so downstream reduction stays correct.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Set up A and B pointers.
    # offs_token // top_k maps a (token, expert)-slot back to its source row
    # of A, since each input row is duplicated top_k times in sorted order.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )

    # Load quantization scales based on mode
    if use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            # block-wise quantization: scales are re-read per K-step inside
            # the main loop; only the base pointers are computed here.
            # NOTE(review): assumes a_scale is laid out [M, K/group_k] and
            # b_scale [E, N/group_n, K/group_k] — confirm against caller.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        elif per_channel_quant:
            # per-channel quantization: one scale per output channel of B,
            # one scale per token row of A.
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:
            # per-tensor quantization: single scalar per tensor / per expert.
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)

    # Main GEMM loop: accumulate in float32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the ragged final K-tile (and padding rows of A).
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise path: dequantize each partial product with the
                # scales of the current K-group before accumulating.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                # Per-tensor / per-channel: scales are constant over K, so
                # dequantization is deferred until after the loop.
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            # Fused dot-accumulate: on SM90 this maps to WGMMA with
            # in-place accumulation, avoiding a separate add instruction.
            accumulator = tl.dot(a, b, acc=accumulator)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Post-loop dequantization (per-tensor / per-channel modes only; the
    # block-wise mode already applied its scales inside the loop).
    if (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
        accumulator = accumulator * a_scale * b_scale

    # Router weight multiplication (in float32 for numerical stability)
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back, skipping padding rows and the ragged N tail.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

215 

216 

217def get_default_config( 

218 M: int, 

219 E: int, 

220 N: int, 

221 K: int, 

222 topk: int, 

223 dtype: str | None, 

224 block_shape: list[int] | None = None, 

225) -> dict[str, int]: 

226 """Return a reasonable default Triton config for the fused MoE kernel.""" 

227 if dtype == "fp8_w8a8" and block_shape is not None: 

228 config = { 

229 "BLOCK_SIZE_M": 16 if M <= 64 else 64, 

230 "BLOCK_SIZE_N": block_shape[0], 

231 "BLOCK_SIZE_K": block_shape[1], 

232 "GROUP_SIZE_M": 1 if M <= 16 else 32, 

233 "num_warps": 4, 

234 "num_stages": 3, 

235 } 

236 else: 

237 if M <= 32: 

238 block_m = 16 

239 elif M <= 96: 

240 block_m = 32 

241 elif M <= 512: 

242 block_m = 64 

243 else: 

244 block_m = 128 

245 

246 # --- Tile sizing optimised for H100/H800 SM90 GPUs --- 

247 # Larger N/K tiles improve compute intensity and reduce grid 

248 # launches for the common case where N is large (e.g. 14336). 

249 if N >= 4096: 

250 block_n = 128 if M <= 128 else 256 

251 elif N >= 1024: 

252 block_n = 64 if M <= 64 else 128 

253 else: 

254 block_n = 64 if M <= 64 else 128 

255 

256 # K-tile: 128 gives better arithmetic intensity. 

257 if dtype == "fp8_w8a8": 

258 block_k = 128 

259 elif K >= 4096 or M <= 64: 

260 block_k = 128 

261 else: 

262 block_k = 64 

263 

264 # Group-M: promotes L2 reuse across M-blocks. 

265 tokens_per_expert = (M * topk) // max(E, 1) 

266 if tokens_per_expert > 128: 

267 group_m = 16 

268 elif tokens_per_expert > 32: 

269 group_m = 8 

270 else: 

271 group_m = 1 

272 

273 num_warps = 4 if block_m * block_n < 8192 else 8 

274 num_stages = 3 

275 

276 # Shared-memory guard (~232 KB on H100/H800). 

277 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

278 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

279 num_stages -= 1 

280 

281 config = { 

282 "BLOCK_SIZE_M": block_m, 

283 "BLOCK_SIZE_N": block_n, 

284 "BLOCK_SIZE_K": block_k, 

285 "GROUP_SIZE_M": group_m, 

286 "num_warps": num_warps, 

287 "num_stages": num_stages, 

288 } 

289 return config 

290 

291 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> None:
    """
    Launch the fused_moe_kernel Triton kernel.

    Args:
        A: Input activations [M, K]
        B: Expert weight matrices [E, N, K]
        C: Output buffer [M, topk, N]
        A_scale: Activation quantization scale (or None)
        B_scale: Weight quantization scale (or None)
        topk_weights: Router weights [M, topk] (or None)
        sorted_token_ids: From moe_align_block_size
        expert_ids: From moe_align_block_size
        num_tokens_post_padded: From moe_align_block_size
        mul_routed_weight: Whether to multiply router weights in-kernel
        top_k: Number of top experts per token
        config: Triton config dict with BLOCK_SIZE_M/N/K, GROUP_SIZE_M, etc.
        compute_type: Triton dtype for compute (tl.bfloat16, tl.float16, etc.)
        use_fp8_w8a8: FP8 weight+activation quantization
        use_int8_w8a8: INT8 weight+activation quantization
        per_channel_quant: Per-channel quantization mode
        block_shape: [block_n, block_k] for block-wise quantization
    """
    # Kernel requires contiguous innermost dims for weights and token ids.
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    # Quantized modes need weight scales; unquantized must pass no scales.
    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    # Each source row appears top_k times; entries >= num_tokens in
    # sorted_token_ids are padding and are masked out in-kernel.
    num_tokens = M * top_k
    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # Small-batch optimization: the sorted id tensor is sized for the
        # worst case; clamp EM so fewer M-blocks are launched.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])

    # 1-D launch: (num M-blocks) x (num N-blocks), flattened.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )

    # Don't mutate the caller's config dict.
    config = config.copy()
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # K-tile must not straddle a quantization scale block.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))

    # NOTE: positional order below must match fused_moe_kernel's signature
    # exactly. B is [E, N, K], so stride_bk = B.stride(2) and
    # stride_bn = B.stride(1). C is [M, topk, N] but rows are indexed by the
    # flattened sorted token id, hence stride_cm = C.stride(1),
    # stride_cn = C.stride(2). Absent scales pass stride 0.
    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        0 if block_shape is None else block_shape[0],  # group_n
        0 if block_shape is None else block_shape[1],  # group_k
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        per_channel_quant=per_channel_quant,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )

396 

397 

def _apply_silu_and_mul(out: torch.Tensor, inp: torch.Tensor) -> None:
    """Apply SiLU-and-Mul activation: out = SiLU(inp[:, :N]) * inp[:, N:]."""
    half = inp.shape[-1] // 2
    gate = inp[:, :half]
    up = inp[:, half:]
    silu_and_mul_kernel(gate, up, out0=out)

403 

404 

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int = -1,
    activation: str = "silu",
) -> torch.Tensor:
    """
    Complete fused MoE forward pass (bf16/fp16, no quantization).

    Pipeline:
        moe_align_block_size → GEMM1(up+gate) → SiLU+Mul → GEMM2(down) → moe_sum

    Args:
        hidden_states: [num_tokens, hidden_size]
        w1: [E, intermediate_size * 2, hidden_size] (gate + up projection)
        w2: [E, hidden_size, intermediate_size] (down projection)
        topk_weights: [num_tokens, topk]
        topk_ids: [num_tokens, topk]
        num_experts: Total number of experts (default: inferred from w1)
        activation: Activation function name ("silu")

    Returns:
        output: [num_tokens, hidden_size]
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    num_tokens, hidden_size = hidden_states.shape
    expert_count = w1.shape[0]
    gate_up_size = w1.shape[1]  # intermediate_size * 2
    top_k = topk_ids.shape[1]

    if num_experts <= 0:
        num_experts = expert_count

    # Map the torch dtype to the matching Triton compute dtype.
    dtype_to_compute = {
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
        torch.float32: tl.float32,
    }
    compute_type = dtype_to_compute.get(hidden_states.dtype)
    if compute_type is None:
        raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")

    # Kernel config shared by both GEMMs.
    config = get_default_config(
        num_tokens, expert_count, w2.shape[1], hidden_size, top_k, None
    )

    # Step 1: Align tokens to experts (pad to BLOCK_SIZE_M multiples).
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], num_experts
    )

    # Allocate intermediate buffers:
    # GEMM1 output [M, topk, N], activation output [M*topk, N/2],
    # GEMM2 output [M, topk, K], and the final reduced output [M, K].
    buf_kwargs = {"dtype": hidden_states.dtype, "device": hidden_states.device}
    gemm1_out = torch.empty((num_tokens, top_k, gate_up_size), **buf_kwargs)
    act_out = torch.empty((num_tokens * top_k, gate_up_size // 2), **buf_kwargs)
    gemm2_out = torch.empty((num_tokens, top_k, hidden_size), **buf_kwargs)
    output = torch.zeros((num_tokens, hidden_size), **buf_kwargs)

    # Step 2: GEMM1 — hidden_states @ W1 → gemm1_out
    invoke_fused_moe_triton_kernel(
        A=hidden_states,
        B=w1,
        C=gemm1_out,
        A_scale=None,
        B_scale=None,
        topk_weights=None,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=False,
        top_k=top_k,
        config=config,
        compute_type=compute_type,
    )

    # Step 3: Activation — SiLU(gate) * up
    _apply_silu_and_mul(act_out, gemm1_out.view(-1, gate_up_size))

    # Step 4: GEMM2 — intermediate @ W2 → gemm2_out
    # Router weights are folded in here.
    invoke_fused_moe_triton_kernel(
        A=act_out,
        B=w2,
        C=gemm2_out,
        A_scale=None,
        B_scale=None,
        topk_weights=topk_weights,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=True,
        top_k=1,  # After activation, each token-expert pair is independent
        config=config,
        compute_type=compute_type,
    )

    # Step 5: Reduce — sum over topK experts
    moe_sum(gemm2_out, output)

    return output