Coverage for src/flag_gems/fused/fused_moe.py: 40%

614 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source files under vllm/model_executor/layers/: 

6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl 

7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation 

8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input 

9# fused_moe/config.py – _get_config_dtype_str 

10# quantization/utils/mxfp4_utils.py – dequant_mxfp4 

11# quantization/utils/mxfp6_utils.py – dequant_mxfp6 

12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE 

13 

14 

15import functools 

16import logging 

17from enum import Enum 

18from typing import Any, Optional 

19 

20import torch 

21import torch.nn.functional as F 

22import triton 

23import triton.language as tl 

24 

25from flag_gems.fused.moe_align_block_size import moe_align_block_size 

26from flag_gems.fused.moe_sum import moe_sum 

27from flag_gems.utils import pointwise_dynamic 

28 

29logger = logging.getLogger(__name__) 

30 

31# OCP MX quantization helpers (requires amd-quark) 

32 

33OCP_MX_BLOCK_SIZE = 32 

34 

35 

def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4-packed tensor to ``float_dtype``.

    Delegates to ``quark.torch.kernel.mx.dq_mxfp4``.

    Raises:
        ImportError: when the optional amd-quark package is not installed
            (chained to the original import failure).
    """
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP4") from err

    result = mx.dq_mxfp4(x, scale, float_dtype)
    return result

48 

49 

def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6-packed tensor to ``float_dtype``.

    Unpacks the payload with quark's pack helper, turns the shared E8M0
    exponent bytes into power-of-two float scales, and applies quark's
    grouped dequantization along the last axis.

    Raises:
        ImportError: when the optional amd-quark package is not installed.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP6") from err

    packer = create_pack_method(None, dtype=quant_dtype)
    unpacked = packer.unpack(x, reorder=False)

    # E8M0 exponent bytes (bias 127) -> power-of-two float scales.
    exponents = scale.view(torch.uint8).to(torch.int16) - 127
    fp_scale = 2 ** exponents.to(float_dtype)

    dequantized = dequantize_fp4_fp6_per_group(
        unpacked,
        fp_scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)

77 

78 

79# Activation quantization helpers 

80 

81 

82def get_moe_configs( 

83 E: int, 

84 N: int, 

85 dtype: str | None, 

86 block_n: int | None = None, 

87 block_k: int | None = None, 

88) -> dict[int, Any] | None: 

89 """Return None; FlagGems uses get_default_config instead.""" 

90 return None 

91 

92 

def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick a Triton kernel config for the fused MoE GEMMs.

    Tries tuned configs via ``get_moe_configs`` (currently always None in
    FlagGems) and otherwise derives one from ``get_default_config``.

    Args:
        w1_shape: shape of the first expert weight tensor; its last dim is
            used as K.
        w2_shape: shape of the second expert weight tensor; interpreted as
            (E, ..., N) with N taken from the last dim.
        top_k: number of experts routed per token.
        dtype: quantization tag (e.g. "fp8_w8a8", "int4_w4a16") or None.
        M: number of input tokens.
        block_shape: optional [block_n, block_k] of block-quantized weights.

    Returns:
        Dict of Triton meta-parameters (BLOCK_SIZE_M/N/K, GROUP_SIZE_M,
        num_warps, num_stages).
    """
    # NOTE: the previous version carried an `override_config` local that
    # was always None, leaving its branch unreachable; it has been removed.
    E, _, N = w2_shape
    if dtype == "int4_w4a16":
        # int4 weights are nibble-packed: the stored dim is half the
        # logical N, so double it for config selection.
        N = N * 2
    block_n = block_shape[0] if block_shape else 0
    block_k = block_shape[1] if block_shape else 0
    configs = get_moe_configs(E, N, dtype, block_n, block_k)

    if configs:
        # Pick the tuned config whose batch size is closest to M.
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
    return config

118 

119 

120def _get_config_quant_dtype( 

121 use_fp8_w8a8: bool, 

122 use_int8_w8a8: bool, 

123 ocp_mx_scheme: str | None, 

124) -> None | torch.dtype | str: 

125 """Map quantization flags to the corresponding dtype.""" 

126 if use_fp8_w8a8: 

127 return torch.float8_e4m3fn 

128 elif use_int8_w8a8: 

129 return torch.int8 

130 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": 

131 return "mxfp4" 

132 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: 

133 return "mxfp6_e3m2" 

134 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: 

135 return "mxfp6_e2m3" 

136 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: 

137 return torch.bfloat16 

138 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: 

139 return torch.float8_e4m3fn 

140 

141 return None 

142 

143 

def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose BLOCK_SIZE_N / BLOCK_SIZE_K for the WNA16 MoE kernels.

    Returns {} when the caller's config already fixes both block sizes.
    The Triton path (use_moe_wna16_cuda=False) uses a small static table
    keyed on the effective batch size; the CUDA path starts from 128x128
    blocks and grows them heuristically while the launch-grid estimate is
    large.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # Both block sizes already pinned by the caller; nothing to add.
        return {}
    if not use_moe_wna16_cuda:
        # Triton kernel: smaller BLOCK_SIZE_N for effective batch size 1.
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        # NOTE(review): num_n_blocks is derived from size_k and
        # num_k_blocks from size_n — the names look swapped relative to
        # their inputs. This mirrors the adapted upstream code; confirm
        # before changing.
        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        # True division keeps num_m_blocks fractional; it is only used in
        # comparisons and products below, never as an index.
        num_m_blocks = (
            num_valid_tokens + block_size_m - 1
        ) / block_size_m + num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
            block_size_k = 256
            # NOTE(review): block_size_k was just set to 256, so
            # 256 // block_size_k == 1 and this division is a no-op;
            # the pre-update block_size_k was presumably intended.
            # Kept as-is to match the adapted source.
            num_blocks = num_blocks // (256 // block_size_k)

        if (
            num_m_blocks <= 16
            # NOTE(review): the following condition is duplicated verbatim.
            and size_k % (block_size_k * 2) == 0
            and size_k % (block_size_k * 2) == 0
            and block_size_k <= 512
            and num_blocks >= 512
        ):
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            # Still too many blocks: widen the N tile instead.
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            block_size_n = 1024

        # Snap block_size_k so it divides size_k and aligns to group_size.
        block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

202 

203 

204def get_default_config( 

205 M: int, 

206 E: int, 

207 N: int, 

208 K: int, 

209 topk: int, 

210 dtype: str | None, 

211 block_shape: list[int] | None = None, 

212) -> dict[str, int]: 

213 """Default Triton config for fused MoE kernel.""" 

214 if dtype == "fp8_w8a8" and block_shape is not None: 

215 config = { 

216 "BLOCK_SIZE_M": 16 if M <= 64 else 64, 

217 "BLOCK_SIZE_N": block_shape[0], 

218 "BLOCK_SIZE_K": block_shape[1], 

219 "GROUP_SIZE_M": 1 if M <= 16 else 32, 

220 "num_warps": 4, 

221 "num_stages": 3, 

222 } 

223 else: 

224 if M <= 32: 

225 block_m = 16 

226 elif M <= 96: 

227 block_m = 32 

228 elif M <= 512: 

229 block_m = 64 

230 else: 

231 block_m = 128 

232 

233 # Tile sizing for H100/H800 

234 if N >= 4096: 

235 block_n = 128 if M <= 128 else 256 

236 elif N >= 1024: 

237 block_n = 64 if M <= 64 else 128 

238 else: 

239 block_n = 64 if M <= 64 else 128 

240 

241 if dtype == "fp8_w8a8": 

242 block_k = 128 

243 elif K >= 4096 or M <= 64: 

244 block_k = 128 

245 else: 

246 block_k = 64 

247 

248 tokens_per_expert = (M * topk) // max(E, 1) 

249 if tokens_per_expert > 128: 

250 group_m = 16 

251 elif tokens_per_expert > 32: 

252 group_m = 8 

253 else: 

254 group_m = 1 

255 

256 num_warps = 4 if block_m * block_n < 8192 else 8 

257 num_stages = 3 

258 

259 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

260 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

261 num_stages -= 1 

262 

263 config = { 

264 "BLOCK_SIZE_M": block_m, 

265 "BLOCK_SIZE_N": block_n, 

266 "BLOCK_SIZE_K": block_k, 

267 "GROUP_SIZE_M": group_m, 

268 "num_warps": num_warps, 

269 "num_stages": num_stages, 

270 } 

271 return config 

272 

273 

274def _get_config_dtype_str( 

275 dtype: Optional[torch.dtype] = None, 

276 use_fp8_w8a8: bool = False, 

277 use_fp8_w8a16: bool = False, 

278 use_int8_w8a16: bool = False, 

279 use_int4_w4a16: bool = False, 

280 ocp_mx_scheme: str | None = None, 

281) -> str | None: 

282 """Return dtype string for kernel config lookup.""" 

283 if use_fp8_w8a8: 

284 return "fp8_w8a8" 

285 elif use_fp8_w8a16: 

286 return "fp8_w8a16" 

287 elif use_int8_w8a16: 

288 return "int8_w8a16" 

289 elif use_int4_w4a16: 

290 return "int4_w4a16" 

291 elif ocp_mx_scheme is not None: 

292 return None 

293 elif dtype == torch.float: 

294 return "float32" 

295 return None 

296 

297 

298# MoE activation enum 

299 

300 

class MoEActivation(Enum):
    """Activation functions supported by MoE MLP blocks.

    Gated members consume an input of shape [..., 2*d] (gate and up
    halves) and produce [..., d]; the ``*_NO_MUL`` members map [..., d]
    to [..., d] without the up-projection multiply.
    """

    # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """True for every member except the ``*_no_mul`` variants."""
        return not self.value.endswith("_no_mul")

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated twin of this activation, if one exists.

        SILU/GELU/RELU2 map to their ``*_NO_MUL`` counterparts; every
        other member (including the already non-gated ones) maps to
        itself.
        """
        twin_name = f"{self.name}_NO_MUL"
        if self.is_gated and twin_name in MoEActivation.__members__:
            return MoEActivation[twin_name]
        return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Parse an activation name; raise ValueError for unknown names."""
        matches = [member for member in cls if member.value == s]
        if matches:
            return matches[0]
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return N for non-gated, N // 2 for gated activations."""
        return N // 2 if activation.is_gated else N

341 

342 

def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply a MoE activation, writing the result into ``output``.

    Gated activations split ``input`` into [gate | up] halves along the
    last dim; non-gated ones transform ``input`` elementwise. SILU and
    SWIGLUOAI use the FlagGems Triton kernel; the rest are plain PyTorch.
    Returns ``output`` for convenience.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    if activation.is_gated:
        half = output.size(-1)
        gate, up = input[:, :half], input[:, half:]
        if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
            # Fused Triton path writes straight into `output`.
            _silu_and_mul_kernel(gate, up, out0=output)
        elif activation == MoEActivation.GELU:
            output.copy_(F.gelu(gate) * up)
        elif activation == MoEActivation.SWIGLUSTEP:
            output.copy_(torch.sigmoid(gate) * up)
        elif activation == MoEActivation.RELU2:
            output.copy_(F.relu(gate).square() * up)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # relu(x)^2; note the relu is applied in place on `input`.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output

390 

391 

392def _fp8_quantize( 

393 A: torch.Tensor, 

394 A_scale: Optional[torch.Tensor], 

395 per_act_token: bool, 

396 block_shape: Optional[list[int]] = None, 

397) -> tuple[torch.Tensor, torch.Tensor]: 

398 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" 

399 fp8_dtype = torch.float8_e4m3fn 

400 finfo = torch.finfo(fp8_dtype) 

401 fp8_max = finfo.max 

402 fp8_min = finfo.min 

403 eps = 1e-10 

404 

405 if block_shape is not None: 

406 assert not per_act_token 

407 assert len(block_shape) == 2 

408 block_k = block_shape[1] 

409 assert A.size(-1) % block_k == 0 

410 orig_shape = A.shape 

411 A_flat = A.reshape(-1, A.size(-1)) 

412 M, K = A_flat.shape 

413 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

414 amax = ( 

415 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

416 ) 

417 scale = amax / fp8_max 

418 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

419 A_q = A_q.reshape(orig_shape) 

420 scale = scale.reshape(M, K // block_k) 

421 return A_q, scale 

422 

423 elif per_act_token: 

424 A_flat = A.reshape(-1, A.size(-1)) 

425 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

426 scale = amax / fp8_max 

427 min_scale = torch.tensor( 

428 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device 

429 ) 

430 scale = scale.clamp(min=min_scale) 

431 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

432 A_q = A_q.reshape(A.shape) 

433 scale = scale.reshape(A.shape[:-1] + (1,)) 

434 return A_q, scale 

435 

436 else: 

437 if A_scale is not None: 

438 scale = ( 

439 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

440 ) 

441 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

442 return A_q, A_scale 

443 else: 

444 amax = A.abs().amax().clamp(min=eps).to(torch.float32) 

445 scale = amax / fp8_max 

446 iscale = 1.0 / scale 

447 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

448 return A_q, scale.view(1) 

449 

450 

451def _int8_quantize( 

452 A: torch.Tensor, 

453 A_scale: Optional[torch.Tensor], 

454 per_act_token: bool, 

455 block_shape: Optional[list[int]] = None, 

456) -> tuple[torch.Tensor, torch.Tensor]: 

457 """INT8 quantization: per-tensor, per-token, or block-wise.""" 

458 iinfo = torch.iinfo(torch.int8) 

459 int8_max = iinfo.max 

460 int8_min = iinfo.min 

461 eps = 1e-10 

462 

463 if block_shape is not None: 

464 assert not per_act_token 

465 assert len(block_shape) == 2 

466 block_k = block_shape[1] 

467 assert A.size(-1) % block_k == 0 

468 orig_shape = A.shape 

469 A_flat = A.reshape(-1, A.size(-1)) 

470 M, K = A_flat.shape 

471 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

472 amax = ( 

473 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

474 ) 

475 scale = amax / int8_max 

476 A_q = ( 

477 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

478 ) 

479 A_q = A_q.reshape(orig_shape) 

480 scale = scale.reshape(M, K // block_k) 

481 return A_q, scale 

482 

483 elif per_act_token: 

484 A_flat = A.reshape(-1, A.size(-1)) 

485 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

486 scale = amax / int8_max 

487 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

488 A_q = A_q.reshape(A.shape) 

489 scale = scale.reshape(A.shape[:-1] + (1,)) 

490 return A_q, scale 

491 

492 else: 

493 assert A_scale is not None, "int8 per-tensor requires A_scale" 

494 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

495 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

496 return A_q, A_scale 

497 

498 

499def moe_kernel_quantize_input( 

500 A: torch.Tensor, 

501 A_scale: Optional[torch.Tensor], 

502 quant_dtype: None | torch.dtype | str, 

503 per_act_token_quant: bool, 

504 block_shape: Optional[list[int]] = None, 

505 ocp_mx_scheme: str | None = None, 

506) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 

507 """Quantize MoE input activations before GEMM.""" 

508 if ocp_mx_scheme is not None: 

509 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: 

510 pass 

511 elif ocp_mx_scheme.endswith("a_fp8"): 

512 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False) 

513 A = (qA.float() * qA_scale.float()).to(A.dtype) 

514 return A, None 

515 

516 if quant_dtype is None: 

517 return A, A_scale 

518 elif quant_dtype == torch.float8_e4m3fn: 

519 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) 

520 elif quant_dtype == torch.int8: 

521 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) 

522 else: 

523 return A, A_scale 

524 

525 

526def _ensure_block_size_k_divisible( 

527 size_k: int, block_size_k: int, group_size: int 

528) -> int: 

529 """Find largest block_size_k that divides size_k and is divisible by group_size.""" 

530 if size_k % block_size_k == 0 and block_size_k % group_size == 0: 

531 return block_size_k 

532 

533 max_search = min(block_size_k, size_k) 

534 start = (max_search // group_size) * group_size 

535 for candidate in range(start, group_size - 1, -group_size): 

536 if size_k % candidate == 0: 

537 return candidate 

538 

539 if size_k % group_size == 0: 

540 return group_size 

541 

542 return size_k 

543 

544 

# Elementwise SiLU(x) * y, generated as a FlagGems pointwise op.
# promotion_methods pairs args (0, 1) under the DEFAULT dtype-promotion
# rule for the single output.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Compute SiLU in fp32 for accuracy: silu(x) = x / (1 + exp(-x)).
    x_fp32 = x.to(tl.float32)
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    # The multiply by y promotes per the decorator's promotion rule.
    return x_silu * y

551 

552 

@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    # Fill this program's (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C with
    # zeros. Used by the MoE kernels when a tile's expert id is -1,
    # i.e. the expert is not owned by the current expert-parallel rank.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Mask out padding tokens (rows) and the tail of the N dimension.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

571 

572 

@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for
    one expert, dequantizing int4- or int8-packed B tiles on the fly with
    per-group scales (and optional zero points). Token rows are routed
    through sorted_token_ids/expert_ids produced by moe_align_block_size.
    NOTE(review): SPLIT_K is accepted but never used in this body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        # This tile lies entirely in the padded tail: nothing to do.
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token // top_k maps a routed-token slot back to its source row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # int4: two nibbles per byte, so the K index is halved and
        # b_shifter selects the low/high nibble.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without explicit zero points, symmetric offsets are assumed:
    # 8 for int4, 128 for int8.
    # NOTE(review): the second test is `if`, not `elif`, so the int4
    # no-zp assignment and this chain are evaluated independently;
    # this mirrors the adapted upstream code.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # Zero points are nibble-packed along N as well.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if not block_k_diviable:
            # K not a multiple of BLOCK_SIZE_K: mask the ragged tail.
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            # Extract the 4-bit value for each K position.
            b = (b >> b_shifter) & 0xF

        # Per-(group of K) scales for this expert.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize: (b - zero_point) * scale, then MMA in compute_type.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K tile (int4 advances half as many bytes).
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by its router weight.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

763 

764 

@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    naive_block_assignment: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """Fused MoE kernel: token x expert GEMM with quantization support.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for a
    single expert. Supports unquantized, int8-w8a16 (weight-only), and
    fp8/int8 w8a8 activations with tensor-wise, channel-wise, or
    (group_n x group_k) block-wise scales, plus an optional per-expert
    bias and router-weight scaling.
    NOTE(review): SPLIT_K is accepted but never used in this body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        # This tile lies entirely in the padded tail: nothing to do.
        return
    if not naive_block_assignment:
        # Standard path: token rows come from moe_align_block_size.
        offs_token_id = pid_m * BLOCK_SIZE_M + offs
        offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    else:
        # Naive path: one real token per block (row 0), the rest padded.
        offs_token = tl.where(
            offs == 0,
            pid_m,  # first element = pid_m
            num_valid_tokens,  # remaining elements = constant
        )
    offs_token = offs_token.to(tl.int64)  # prevent int32 overflow

    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert not in current EP rank, write zeros
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token // top_k maps a routed-token slot back to its source row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        # Weight-only int8: one scale per output channel of this expert.
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:  # block-wise
            # Scales are loaded per K tile inside the main loop.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        elif per_channel_quant:  # channel-wise
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            # One activation scale per token row.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:  # tensor-wise
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        # Per-expert bias over the N dimension, added after dequant.
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise: refresh scales for the current K group and
                # apply them per partial product.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance to the next K tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Dequantization
    if use_int8_w8a16:
        accumulator = accumulator * b_scale
    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
        # Tensor- or channel-wise scales are applied once, after the loop.
        accumulator = accumulator * a_scale * b_scale

    if HAS_BIAS:
        accumulator += bias[None, :]

    # Router weight multiplication (must be in fp32)
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

956 

957 

def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch the GPTQ/AWQ (WNA16) fused-MoE Triton kernel.

    Args:
        A: activations, first dim is the token count M.
        B: packed expert weights; indexed by expert along dim 0 and B's
            dim 1 is used as the kernel's N dimension.
        C: output tensor; strides 1 and 2 are passed as the row/col
            strides, so it is expected to be 3-D — TODO confirm layout
            against callers.
        B_scale: per-group weight scales, must be 3-D.
        B_zp: optional packed zero points, 3-D when present.
        topk_weights: router weights, applied when mul_routed_weight.
        sorted_token_ids / expert_ids / num_tokens_post_padded: routing
            metadata from moe_align_block_size.
        config: Triton meta-parameters; BLOCK_SIZE_N/K are filled in by
            get_moe_wna16_block_config when absent.
        block_shape: [0, group_size] — block_shape[0] must be 0 here.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    # Copy before mutating: the caller's config dict must stay untouched.
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # FIX: the expert count is B's leading dim. The previous code
            # passed B.size(1), which is the N dimension (it is also used
            # as N in the grid and kernel args above). Unused on the
            # use_moe_wna16_cuda=False path, but corrected for safety.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

1046 

1047 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Validates the scale tensors against the requested quantization mode,
    computes the (flattened, 1-D) launch grid, and forwards all tensors,
    strides and compile-time flags to ``fused_moe_kernel``.

    Notes:
        * ``sorted_token_ids is None`` selects the "naive block assignment"
          path inside the kernel (see ``naive_block_assignment`` below).
        * Argument order in the kernel launch is positional and must match
          the kernel signature exactly — do not reorder.
    """
    # Router weights are mandatory when they are applied, and both the weight
    # and index tensors must be contiguous along their innermost dimension.
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        # Activation+weight quantization: B_scale must tile B per block_shape.
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        # Weight-only quantization: block_shape[0] == 0 means no N blocking.
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        # Unquantized path: no scales expected.
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if A.size(0) < config["BLOCK_SIZE_M"]:
            # Small-batch optimization: cap EM so empty padded blocks are
            # skipped (assumes each token's top-k expert ids are unique).
            EM = min(
                sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
            )
    else:
        # Naive assignment: one block per (token, expert) pair.
        EM = num_tokens * config["BLOCK_SIZE_M"]
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    config = config.copy()
    config["SPLIT_K"] = 1
    # BLOCK_SIZE_K is passed explicitly (not via **config) because it may be
    # clamped to the quantization group size below.
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        # Scale strides default to 0 when the tensor is absent or has fewer
        # dims than the kernel's addressing assumes.
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        # group_k / group_n: 0 disables block-wise dequantization.
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )

1155 

1156 

def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only int8/int4 quantization with a K-group size
    (``block_shape[1] > 0``) routes to the WNA16 GPTQ/AWQ kernel; every
    other combination uses the generic ``fused_moe_kernel`` launcher.
    """
    # Same preconditions as the launchers: router weights must exist when
    # applied; index/weight tensors contiguous along the innermost dim.
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: add dedicated fast paths for other precision combinations
    # (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16).
    # (Replaces a previous unreachable `if False: pass` placeholder.)
    if (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    ):
        # The WNA16 kernel has no bias support.
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )

1240 

1241 

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused Mixture-of-Experts pipeline.

    Per chunk of tokens: (optionally quantize activations) -> grouped GEMM
    with ``w1`` -> activation -> (optionally re-quantize) -> grouped GEMM
    with ``w2`` -> weighted sum over the top-k experts into the output.

    Returns ``hidden_states`` itself when ``inplace`` is True, otherwise a
    newly allocated tensor of the same shape.
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints
    if use_int4_w4a16:
        # INT4 stored unpacked in INT8 containers (full K dim)
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
    elif ocp_mx_scheme is not None:
        # MX packing: fp4 stores two values per byte; fp6 stores four values
        # in three bytes, hence the size ratios below.
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    # w1: (E, N, K-like); w2: (E, K, N-like). N is the intermediate width,
    # K the output hidden size.
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # Tokens are processed in chunks to bound workspace memory.
    CHUNK_SIZE: int = 64 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Views alias cache13: cache1 holds the w1 GEMM output, cache3 the w2
    # GEMM output; cache1 is fully consumed before cache3 is written.
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    # For gated activations (e.g. SiLU-and-mul) the output width is N // 2.
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # Map the torch dtype to the Triton compute dtype.
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    if use_int8_w8a16 or use_int4_w4a16:
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        # Flags cleared: the kernels now see plain 16/32-bit weights.
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        # Quantize this chunk's activations (no-op when quant_dtype is None).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # Heuristic: when token*top_k work is much smaller than the expert
        # count, skip the align/sort step and assign one block per
        # (token, expert) pair directly.
        SPARSITY_FACTOR = 4
        naive_block_assignment = (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            # NOTE(review): this uses the full topk_ids.numel() rather than
            # curr_topk_ids.numel(); for inputs spanning multiple chunks it
            # over-counts the padded total — confirm intended.
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            # sorted_token_ids=None signals the naive path to the kernel.
            sorted_token_ids = None

        # First grouped GEMM: x @ w1 (+ bias), optionally applying router
        # weights on the input side.
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # Re-quantize the activation output for the second GEMM.
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        if expert_map is not None:
            # Rows for experts not owned by this rank are never written;
            # zero so the final sum is unaffected.
            intermediate_cache3.zero_()

        # Second grouped GEMM: act_out @ w2 (+ bias); router weights applied
        # here unless they were already applied on the input. top_k=1 because
        # rows are already expanded to (token, expert) pairs.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # Reduce over the top-k expert outputs into the chunk's output rows.
        # NOTE(review): .view(*size()) is a no-op reshape; presumably kept
        # for interface symmetry with upstream.
        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

1531 

1532 

def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """Fused MoE that overwrites ``hidden_states`` with the result.

    Thin wrapper around ``fused_experts_impl`` with ``inplace=True``.
    Returns None; the output lives in ``hidden_states``.
    """
    # Gather the pass-through options once, then forward them as keywords.
    forwarded = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        **forwarded,
    )

1584 

1585 

def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fused MoE that returns a freshly allocated output tensor.

    Thin wrapper around ``fused_experts_impl`` with ``inplace=False``;
    ``hidden_states`` is left untouched.
    """
    # Gather the pass-through options once, then forward them as keywords.
    forwarded = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        **forwarded,
    )