Coverage for src/flag_gems/fused/fused_moe.py: 43%

665 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source files under vllm/model_executor/layers/: 

6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl 

7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation 

8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input 

9# fused_moe/config.py – _get_config_dtype_str 

10# quantization/utils/mxfp4_utils.py – dequant_mxfp4 

11# quantization/utils/mxfp6_utils.py – dequant_mxfp6 

12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE 

13 

14 

15import functools 

16import json 

17import logging 

18import os 

19from enum import Enum 

20from typing import Any, Optional 

21 

22import torch 

23import torch.nn.functional as F 

24import triton 

25import triton.language as tl 

26 

27from flag_gems.fused.moe_align_block_size import moe_align_block_size 

28from flag_gems.fused.moe_sum import moe_sum 

29from flag_gems.utils import pointwise_dynamic 

30 

# Module-level logger, named after this module for log filtering.
logger = logging.getLogger(__name__)

# OCP MX quantization helpers (requires amd-quark)

# Number of elements that share one scale in OCP Microscaling (MX) formats
# (mxfp4 / mxfp6); used as the per-group size when dequantizing.
OCP_MX_BLOCK_SIZE = 32

36 

37 

@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded fused-MoE tuning table shipped with flag_gems.

    Returns a ``(configs, fallback)`` pair:
      * ``configs`` maps device name -> lookup key -> M (int) -> kernel
        config dict (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages).
      * ``fallback`` maps untuned device names to an equivalent tuned device.

    Returns ``({}, {})`` when the JSON file is absent. Cached because the
    table never changes within a process.
    """
    path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.json"
    )
    if not os.path.exists(path):
        return {}, {}

    with open(path, "r") as fh:
        # JSON keys are strings; inner dicts map stringified M to configs.
        raw = json.load(fh)

    fallback = raw.get("_FALLBACK", {})

    # Compact list entries are positional; expand them back into config dicts
    # in this fixed field order.
    field_names = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )

    def _expand(m_dict):
        # Restore int M keys; zip positional lists into dicts, keep dicts as-is.
        return {
            int(m): dict(zip(field_names, cfg)) if isinstance(cfg, list) else cfg
            for m, cfg in m_dict.items()
        }

    table = {
        dev: {key: _expand(m_dict) for key, m_dict in dev_cfgs.items()}
        for dev, dev_cfgs in raw.items()
        if dev != "_FALLBACK"
    }
    return table, fallback

76 

77 

def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4-packed tensor to ``float_dtype``.

    Thin wrapper over ``quark.torch.kernel.mx.dq_mxfp4``.

    Raises:
        ImportError: when the optional amd-quark dependency is missing.
    """
    try:
        from quark.torch.kernel.mx import dq_mxfp4
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP4") from err

    return dq_mxfp4(x, scale, float_dtype)

90 

91 

def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6-packed tensor to ``float_dtype``.

    Unpacks the fp6 payload with quark's pack helper, decodes the shared
    E8M0 scale bytes, then dequantizes per ``OCP_MX_BLOCK_SIZE`` group along
    the last axis.

    Raises:
        ImportError: when the optional amd-quark dependency is missing.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP6") from err

    unpacked = create_pack_method(None, dtype=quant_dtype).unpack(x, reorder=False)

    # E8M0 shared scale: each byte is a biased exponent, decoded as 2**(e-127).
    decoded_scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)

    dequantized = dequantize_fp4_fp6_per_group(
        unpacked,
        decoded_scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)

119 

120 

121# Activation quantization helpers 

122 

123 

@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the normalised CUDA device name (spaces replaced by underscores).

    Follows vLLM's per-device config-file naming convention. Any device whose
    name contains an "H200" token collapses to "NVIDIA_H200". If the exact
    name has no tuned table, the JSON "_FALLBACK" mapping may redirect to an
    equivalently-tuned device; otherwise the raw name is returned.
    """
    device = torch.cuda.get_device_name().replace(" ", "_")
    # Normalise the whole H200 product family to a single config key.
    if "H200" in device.split("_"):
        device = "NVIDIA_H200"

    tables, aliases = get_embedded_moe_configs()
    if device in tables:
        return device

    # Unknown device: consult the fallback mapping (e.g. H800 -> H100, which
    # shares the SM 9.0 architecture).
    alias = aliases.get(device)
    if alias and alias in tables:
        logger.info("Device %s not in config table, falling back to %s", device, alias)
        return alias
    return device

145 

146 

def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """Fetch pre-tuned fused-MoE kernel configs for the current GPU.

    The embedded table (ported from vLLM) is keyed by device name and then by
    the string "E,N,dtype,block_n,block_k". Returns a mapping from tuned M to
    a kernel config dict, or None when no tuned entry exists.
    """
    device_name = _get_device_name()
    tables, _ = get_embedded_moe_configs()
    per_device = tables.get(device_name)
    if per_device is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    # Missing block sizes are encoded as 0 in the lookup key.
    key = f"{E},{N},{dtype},{block_n or 0},{block_k or 0}"
    entry = per_device.get(key)
    if entry is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return entry

183 

184 

def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick the best available Triton config for a fused-MoE problem.

    Consults the embedded per-device tuning table first, choosing the entry
    whose tuned M is closest to the actual batch size M; falls back to the
    heuristic default config when no tuned entry exists.

    The vestigial ``override_config`` branch from the vLLM original (always
    None here, making its ``if`` unreachable) has been removed.
    """
    E, _, N = w2_shape
    # int4 weights are packed two per element along N; restore the logical width.
    if dtype == "int4_w4a16":
        N = N * 2
    block_n = block_shape[0] if block_shape else 0
    block_k = block_shape[1] if block_shape else 0
    configs = get_moe_configs(E, N, dtype, block_n, block_k)

    if configs:
        # Nearest tuned batch size wins.
        return configs[min(configs.keys(), key=lambda m: abs(m - M))]
    return get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)

210 

211 

212def _get_config_quant_dtype( 

213 use_fp8_w8a8: bool, 

214 use_int8_w8a8: bool, 

215 ocp_mx_scheme: str | None, 

216) -> None | torch.dtype | str: 

217 """Map quantization flags to the corresponding dtype.""" 

218 if use_fp8_w8a8: 

219 return torch.float8_e4m3fn 

220 elif use_int8_w8a8: 

221 return torch.int8 

222 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": 

223 return "mxfp4" 

224 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: 

225 return "mxfp6_e3m2" 

226 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: 

227 return "mxfp6_e2m3" 

228 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: 

229 return torch.bfloat16 

230 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: 

231 return torch.float8_e4m3fn 

232 

233 return None 

234 

235 

def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose BLOCK_SIZE_N / BLOCK_SIZE_K for the WNA16 (GPTQ/AWQ) MoE path.

    Returns ``{}`` when the caller's ``config`` already pins both sizes.
    For the Triton path a fixed small tile is returned; for the CUDA path the
    sizes start at 128 and are grown based on an estimate of the launched
    block count. These heuristics are ported from vLLM; several upstream
    quirks (flagged below) are preserved verbatim.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # Both tile sizes already chosen (e.g. from a tuned config).
        return {}
    if not use_moe_wna16_cuda:
        # Triton kernel: effective batch size of 1 favors a narrower N tile.
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        block_size_n = 128
        block_size_k = 128
        # A K tile must cover at least one quantization group.
        if block_size_k <= group_size:
            block_size_k = group_size

        # NOTE(review): num_n_blocks is derived from size_k and num_k_blocks
        # from size_n — the names look swapped, but this mirrors upstream
        # vLLM, and the two only feed the num_blocks product below, so the
        # estimate is unaffected.
        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        # NOTE(review): true division leaves num_m_blocks a float (matches
        # upstream); it is only used in comparisons and products.
        num_m_blocks = (
            num_valid_tokens + block_size_m - 1
        ) / block_size_m + num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
            block_size_k = 256
            # NOTE(review): block_size_k was just set to 256, so this divisor
            # is always 1 and num_blocks is unchanged — preserved from vLLM.
            num_blocks = num_blocks // (256 // block_size_k)

        if (
            num_m_blocks <= 16
            and size_k % (block_size_k * 2) == 0
            # NOTE(review): duplicated size_k condition, preserved from
            # upstream (one of the two was presumably meant to test size_n).
            and size_k % (block_size_k * 2) == 0
            and block_size_k <= 512
            and num_blocks >= 512
        ):
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        # Small N with many blocks: collapse N into a single wide tile.
        if size_n <= 1024 and num_blocks >= 1024:
            block_size_n = 1024

        # Shrink block_size_k so it divides size_k while remaining a multiple
        # of group_size.
        block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

294 

295 

296def get_default_config( 

297 M: int, 

298 E: int, 

299 N: int, 

300 K: int, 

301 topk: int, 

302 dtype: str | None, 

303 block_shape: list[int] | None = None, 

304) -> dict[str, int]: 

305 """Default Triton config for fused MoE kernel. 

306 

307 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100. 

308 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each 

309 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical. 

310 """ 

311 if dtype == "fp8_w8a8" and block_shape is not None: 

312 config = { 

313 "BLOCK_SIZE_M": 16 if M <= 64 else 64, 

314 "BLOCK_SIZE_N": block_shape[0], 

315 "BLOCK_SIZE_K": block_shape[1], 

316 "GROUP_SIZE_M": 1 if M <= 16 else 32, 

317 "num_warps": 4, 

318 "num_stages": 3, 

319 } 

320 else: 

321 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to 

322 # estimate the actual per-expert token count after routing. 

323 tokens_per_expert = M // max(E, 1) 

324 

325 if tokens_per_expert <= 2: 

326 block_m = 16 

327 elif tokens_per_expert <= 4: 

328 block_m = 32 

329 elif tokens_per_expert <= 16: 

330 block_m = 64 

331 else: 

332 block_m = 128 

333 

334 # Tile sizing 

335 if N >= 4096: 

336 block_n = 128 if M <= 128 else 256 

337 elif N >= 1024: 

338 block_n = 64 if M <= 64 else 128 

339 else: 

340 block_n = 64 if M <= 64 else 128 

341 

342 if dtype == "fp8_w8a8": 

343 block_k = 128 

344 elif M <= 64: 

345 block_k = 128 

346 else: 

347 block_k = 64 

348 

349 if tokens_per_expert > 128: 

350 group_m = 16 

351 elif tokens_per_expert > 32: 

352 group_m = 8 

353 else: 

354 group_m = 1 

355 

356 # Prefer 4 warps for small tiles; only use 8 for large M 

357 num_warps = 4 if M <= 128 else 8 

358 num_stages = 3 

359 

360 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

361 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

362 num_stages -= 1 

363 

364 config = { 

365 "BLOCK_SIZE_M": block_m, 

366 "BLOCK_SIZE_N": block_n, 

367 "BLOCK_SIZE_K": block_k, 

368 "GROUP_SIZE_M": group_m, 

369 "num_warps": num_warps, 

370 "num_stages": num_stages, 

371 } 

372 return config 

373 

374 

375def _get_config_dtype_str( 

376 dtype: Optional[torch.dtype] = None, 

377 use_fp8_w8a8: bool = False, 

378 use_fp8_w8a16: bool = False, 

379 use_int8_w8a16: bool = False, 

380 use_int4_w4a16: bool = False, 

381 ocp_mx_scheme: str | None = None, 

382) -> str | None: 

383 """Return dtype string for kernel config lookup.""" 

384 if use_fp8_w8a8: 

385 return "fp8_w8a8" 

386 elif use_fp8_w8a16: 

387 return "fp8_w8a16" 

388 elif use_int8_w8a16: 

389 return "int8_w8a16" 

390 elif use_int4_w4a16: 

391 return "int4_w4a16" 

392 elif ocp_mx_scheme is not None: 

393 return None 

394 elif dtype == torch.float: 

395 return "float32" 

396 return None 

397 

398 

399# MoE activation enum 

400 

401 

class MoEActivation(Enum):
    """Activation functions supported by the fused MoE layers.

    Gated members read the input as [gate | up] along the last dim
    ([..., 2*d] -> [..., d]); ``*_NO_MUL`` members are plain element-wise
    activations ([..., d] -> [..., d]).
    """

    # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """True unless the member's value carries the "_no_mul" suffix."""
        return not self.value.endswith("_no_mul")

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` when none exists."""
        try:
            # The non-gated twin, by construction, is "<value>_no_mul".
            return MoEActivation(f"{self.value}_no_mul")
        except ValueError:
            return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Parse ``s`` into a member; raise ValueError listing valid names."""
        member = next((m for m in cls if m.value == s), None)
        if member is not None:
            return member
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return N for non-gated, N // 2 for gated activations."""
        return N // 2 if activation.is_gated else N

442 

443 

def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply the requested MoE activation, writing the result into ``output``.

    Gated variants split ``input`` into [gate | up] halves along the last dim
    (input width must be 2x output width); non-gated variants operate
    element-wise (equal widths). RELU2_NO_MUL additionally mutates ``input``
    in place. Returns ``output``.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    if activation.is_gated:
        half = output.size(-1)
        gate, up = input[:, :half], input[:, half:]
        if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
            # Triton fused path: silu(gate) * up written directly to output.
            _silu_and_mul_kernel(gate, up, out0=output)
        elif activation == MoEActivation.GELU:
            output.copy_(F.gelu(gate) * up)
        elif activation == MoEActivation.SWIGLUSTEP:
            output.copy_(torch.sigmoid(gate) * up)
        elif activation == MoEActivation.RELU2:
            output.copy_(F.relu(gate).square() * up)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # relu in place on input, then square into output.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output

491 

492 

493def _fp8_quantize( 

494 A: torch.Tensor, 

495 A_scale: Optional[torch.Tensor], 

496 per_act_token: bool, 

497 block_shape: Optional[list[int]] = None, 

498) -> tuple[torch.Tensor, torch.Tensor]: 

499 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" 

500 fp8_dtype = torch.float8_e4m3fn 

501 finfo = torch.finfo(fp8_dtype) 

502 fp8_max = finfo.max 

503 fp8_min = finfo.min 

504 eps = 1e-10 

505 

506 if block_shape is not None: 

507 assert not per_act_token 

508 assert len(block_shape) == 2 

509 block_k = block_shape[1] 

510 assert A.size(-1) % block_k == 0 

511 orig_shape = A.shape 

512 A_flat = A.reshape(-1, A.size(-1)) 

513 M, K = A_flat.shape 

514 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

515 amax = ( 

516 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

517 ) 

518 scale = amax / fp8_max 

519 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

520 A_q = A_q.reshape(orig_shape) 

521 scale = scale.reshape(M, K // block_k) 

522 return A_q, scale 

523 

524 elif per_act_token: 

525 A_flat = A.reshape(-1, A.size(-1)) 

526 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

527 scale = amax / fp8_max 

528 min_scale = torch.tensor( 

529 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device 

530 ) 

531 scale = scale.clamp(min=min_scale) 

532 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

533 A_q = A_q.reshape(A.shape) 

534 scale = scale.reshape(A.shape[:-1] + (1,)) 

535 return A_q, scale 

536 

537 else: 

538 if A_scale is not None: 

539 scale = ( 

540 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

541 ) 

542 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

543 return A_q, A_scale 

544 else: 

545 amax = A.abs().amax().clamp(min=eps).to(torch.float32) 

546 scale = amax / fp8_max 

547 iscale = 1.0 / scale 

548 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

549 return A_q, scale.view(1) 

550 

551 

552def _int8_quantize( 

553 A: torch.Tensor, 

554 A_scale: Optional[torch.Tensor], 

555 per_act_token: bool, 

556 block_shape: Optional[list[int]] = None, 

557) -> tuple[torch.Tensor, torch.Tensor]: 

558 """INT8 quantization: per-tensor, per-token, or block-wise.""" 

559 iinfo = torch.iinfo(torch.int8) 

560 int8_max = iinfo.max 

561 int8_min = iinfo.min 

562 eps = 1e-10 

563 

564 if block_shape is not None: 

565 assert not per_act_token 

566 assert len(block_shape) == 2 

567 block_k = block_shape[1] 

568 assert A.size(-1) % block_k == 0 

569 orig_shape = A.shape 

570 A_flat = A.reshape(-1, A.size(-1)) 

571 M, K = A_flat.shape 

572 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

573 amax = ( 

574 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

575 ) 

576 scale = amax / int8_max 

577 A_q = ( 

578 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

579 ) 

580 A_q = A_q.reshape(orig_shape) 

581 scale = scale.reshape(M, K // block_k) 

582 return A_q, scale 

583 

584 elif per_act_token: 

585 A_flat = A.reshape(-1, A.size(-1)) 

586 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

587 scale = amax / int8_max 

588 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

589 A_q = A_q.reshape(A.shape) 

590 scale = scale.reshape(A.shape[:-1] + (1,)) 

591 return A_q, scale 

592 

593 else: 

594 assert A_scale is not None, "int8 per-tensor requires A_scale" 

595 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

596 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

597 return A_q, A_scale 

598 

599 

600def moe_kernel_quantize_input( 

601 A: torch.Tensor, 

602 A_scale: Optional[torch.Tensor], 

603 quant_dtype: None | torch.dtype | str, 

604 per_act_token_quant: bool, 

605 block_shape: Optional[list[int]] = None, 

606 ocp_mx_scheme: str | None = None, 

607) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 

608 """Quantize MoE input activations before GEMM.""" 

609 if ocp_mx_scheme is not None: 

610 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: 

611 pass 

612 elif ocp_mx_scheme.endswith("a_fp8"): 

613 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False) 

614 A = (qA.float() * qA_scale.float()).to(A.dtype) 

615 return A, None 

616 

617 if quant_dtype is None: 

618 return A, A_scale 

619 elif quant_dtype == torch.float8_e4m3fn: 

620 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) 

621 elif quant_dtype == torch.int8: 

622 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) 

623 else: 

624 return A, A_scale 

625 

626 

627def _ensure_block_size_k_divisible( 

628 size_k: int, block_size_k: int, group_size: int 

629) -> int: 

630 """Find largest block_size_k that divides size_k and is divisible by group_size.""" 

631 if size_k % block_size_k == 0 and block_size_k % group_size == 0: 

632 return block_size_k 

633 

634 max_search = min(block_size_k, size_k) 

635 start = (max_search // group_size) * group_size 

636 for candidate in range(start, group_size - 1, -group_size): 

637 if size_k % candidate == 0: 

638 return candidate 

639 

640 if size_k % group_size == 0: 

641 return group_size 

642 

643 return size_k 

644 

645 

# Elementwise silu(x) * y, used by the gated-SiLU MoE activation path.
# pointwise_dynamic generates the launch/broadcast wrapper; promotion rule
# (0, 1, "DEFAULT") promotes the output dtype from inputs 0 and 1.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Compute in fp32 for accuracy regardless of the input dtype.
    x_fp32 = x.to(tl.float32)
    # silu(x) = x * sigmoid(x), written here as x / (1 + e^-x).
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y

652 

653 

@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Store an all-zero (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C.

    Called by the MoE kernels when a row-block's expert id is -1 (expert not
    owned by this rank), so the corresponding output rows are zeroed instead
    of computed. ``token_mask`` guards rows that are padding, and columns are
    masked against N.
    """
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # C is indexed by the (scattered) token ids, not by pid_m directly.
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

672 

673 

@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for a
    single expert. Tokens arrive pre-sorted by expert (sorted_token_ids_ptr)
    and padded to a multiple of BLOCK_SIZE_M; expert_ids_ptr[pid_m] gives the
    expert for each row-block (-1 => expert not on this rank, tile is
    zeroed). Weights B are stored per expert, either int4 (two nibbles per
    element along K) or int8, with group-wise scales (b_scale_ptr) and
    optional group-wise zero-points (b_zp_ptr), group size `group_size`
    along K. When no explicit zero-point is given, a fixed one is assumed
    (8 for int4, 128 for int8 — see below).

    NOTE(review): SPLIT_K is accepted but never used in this kernel body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        # Row-block lies entirely past the padded token range: nothing to do.
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Padding slots hold ids >= num_valid_tokens; mask them out of loads/stores.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Wrap N offsets modulo N so B loads stay in bounds without a mask.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Each token row is replicated top_k times; // top_k recovers the A row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # int4: two K elements share one stored element; select the nibble
        # with a per-row shift of 0 or 4 bits.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Implicit zero-points when none are supplied: 8 for int4 (mid of the
    # unsigned [0, 15] range) and 128 for int8.
    # NOTE(review): the second branch is a plain `if`, not `elif`, so it
    # starts a new chain — this mirrors upstream vLLM and is behaviorally
    # correct because the three conditions are mutually exclusive.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # Zero-points are themselves nibble-packed along N.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if not block_k_diviable:
            # Ragged K tail: mask scale/zero-point loads past K.
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): B is loaded without a K-tail mask (matches upstream
        # vLLM); presumably the packed-weight buffer is padded so the final
        # partial tile cannot fault — confirm against the weight layout.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            # Extract the 4-bit value for this K row from the packed element.
            b = (b >> b_shifter) & 0xF

        # Group-wise scales: one scale per `group_size` consecutive K rows.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize: (q - zero_point) * scale, then accumulate in fp32.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            # Two K elements per stored element: advance half a block.
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by its router weight (per token-expert pair).
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

864 

865 

866@triton.jit 

867def fused_moe_kernel( 

868 # Pointers to matrices 

869 a_ptr, 

870 b_ptr, 

871 c_ptr, 

872 b_bias_ptr, 

873 a_scale_ptr, 

874 b_scale_ptr, 

875 topk_weights_ptr, 

876 sorted_token_ids_ptr, 

877 expert_ids_ptr, 

878 num_tokens_post_padded_ptr, 

879 # Matrix dimensions 

880 N, 

881 K, 

882 EM, 

883 num_valid_tokens, 

884 stride_am, 

885 stride_ak, 

886 stride_be, 

887 stride_bk, 

888 stride_bn, 

889 stride_cm, 

890 stride_cn, 

891 stride_asm, 

892 stride_ask, 

893 stride_bse, 

894 stride_bsk, 

895 stride_bsn, 

896 stride_bbe, # bias expert stride 

897 stride_bbn, # bias N stride 

898 # Block size for block-wise quantization 

899 group_n: tl.constexpr, 

900 group_k: tl.constexpr, 

901 naive_block_assignment: tl.constexpr, 

902 # Meta-parameters 

903 BLOCK_SIZE_M: tl.constexpr, 

904 BLOCK_SIZE_N: tl.constexpr, 

905 BLOCK_SIZE_K: tl.constexpr, 

906 GROUP_SIZE_M: tl.constexpr, 

907 SPLIT_K: tl.constexpr, 

908 MUL_ROUTED_WEIGHT: tl.constexpr, 

909 top_k: tl.constexpr, 

910 compute_type: tl.constexpr, 

911 use_fp8_w8a8: tl.constexpr, 

912 use_int8_w8a8: tl.constexpr, 

913 use_int8_w8a16: tl.constexpr, 

914 per_channel_quant: tl.constexpr, 

915 HAS_BIAS: tl.constexpr, 

916): 

917 """Fused MoE kernel: token × expert GEMM with quantization support.""" 

918 # Map pid to C block (grouped ordering for L2 reuse) 

919 pid = tl.program_id(axis=0) 

920 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) 

921 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 

922 num_pid_in_group = GROUP_SIZE_M * num_pid_n 

923 group_id = pid // num_pid_in_group 

924 first_pid_m = group_id * GROUP_SIZE_M 

925 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 

926 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 

927 pid_n = (pid % num_pid_in_group) // group_size_m 

928 

929 # Create pointers for first blocks of A and B 

930 offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) 

931 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) 

932 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: 

933 return 

934 if not naive_block_assignment: 

935 offs_token_id = pid_m * BLOCK_SIZE_M + offs 

936 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) 

937 else: 

938 offs_token = tl.where( 

939 offs == 0, 

940 pid_m, # first element = pid_m 

941 num_valid_tokens, # remaining elements = constant 

942 ) 

943 offs_token = offs_token.to(tl.int64) # prevent int32 overflow 

944 

945 token_mask = offs_token < num_valid_tokens 

946 

947 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) 

948 if off_experts == -1: 

949 # Expert not in current EP rank, write zeros 

950 write_zeros_to_output( 

951 c_ptr, 

952 stride_cm, 

953 stride_cn, 

954 pid_n, 

955 N, 

956 offs_token, 

957 token_mask, 

958 BLOCK_SIZE_M, 

959 BLOCK_SIZE_N, 

960 compute_type, 

961 ) 

962 return 

963 

964 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N 

965 offs_k = tl.arange(0, BLOCK_SIZE_K) 

966 a_ptrs = a_ptr + ( 

967 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak 

968 ) 

969 

970 b_ptrs = ( 

971 b_ptr 

972 + off_experts * stride_be 

973 + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 

974 ) 

975 if use_int8_w8a16: 

976 b_scale_ptrs = ( 

977 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

978 ) 

979 b_scale = tl.load(b_scale_ptrs) 

980 

981 if use_fp8_w8a8 or use_int8_w8a8: 

982 if group_k > 0 and group_n > 0: # block-wise 

983 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

984 offs_bsn = offs_bn // group_n 

985 b_scale_ptrs = ( 

986 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn 

987 ) 

988 elif per_channel_quant: # channel-wise 

989 b_scale_ptrs = ( 

990 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

991 ) 

992 b_scale = tl.load(b_scale_ptrs) 

993 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

994 a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] 

995 else: # tensor-wise 

996 a_scale = tl.load(a_scale_ptr) 

997 b_scale = tl.load(b_scale_ptr + off_experts) 

998 if HAS_BIAS: 

999 bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn 

1000 bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) 

1001 # Accumulate C block in fp32 

1002 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

1003 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

1004 a = tl.load( 

1005 a_ptrs, 

1006 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), 

1007 other=0.0, 

1008 ) 

1009 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) 

1010 if use_int8_w8a16: 

1011 accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) 

1012 elif use_fp8_w8a8 or use_int8_w8a8: 

1013 if group_k > 0 and group_n > 0: 

1014 k_start = k * BLOCK_SIZE_K 

1015 offs_ks = k_start // group_k 

1016 a_scale = tl.load( 

1017 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 

1018 ) 

1019 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) 

1020 

1021 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] 

1022 else: 

1023 if use_fp8_w8a8: 

1024 accumulator = tl.dot(a, b, acc=accumulator) 

1025 else: 

1026 accumulator += tl.dot(a, b) 

1027 else: 

1028 accumulator += tl.dot(a, b) 

1029 a_ptrs += BLOCK_SIZE_K * stride_ak 

1030 b_ptrs += BLOCK_SIZE_K * stride_bk 

1031 

1032 # Dequantization 

1033 if use_int8_w8a16: 

1034 accumulator = accumulator * b_scale 

1035 elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0): 

1036 accumulator = accumulator * a_scale * b_scale 

1037 

1038 if HAS_BIAS: 

1039 accumulator += bias[None, :] 

1040 

1041 # Router weight multiplication (must be in fp32) 

1042 if MUL_ROUTED_WEIGHT: 

1043 moe_weight = tl.load( 

1044 topk_weights_ptr + offs_token, 

1045 mask=token_mask, 

1046 other=0, 

1047 ) 

1048 accumulator *= moe_weight[:, None] 

1049 

1050 accumulator = accumulator.to(compute_type) 

1051 

1052 # Write back output 

1053 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

1054 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] 

1055 c_mask = token_mask[:, None] & (offs_cn[None, :] < N) 

1056 tl.store(c_ptrs, accumulator, mask=c_mask) 

1057 

1058 

def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch the GPTQ/AWQ (weight-only int8/int4) fused MoE Triton kernel.

    ``A`` is the activation matrix, ``B`` the per-expert weight tensor and
    ``C`` the output buffer.  ``B_scale`` (and optionally ``B_zp``) carry the
    group-wise dequantization parameters; ``block_shape[1]`` is the
    quantization group size (``block_shape[0]`` must be 0 for weight-only
    quantization).
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    num_input_tokens = A.size(0)
    num_tokens = num_input_tokens * top_k

    # Padded routing length; clamp for small batches where top-k ids are
    # (assumed) unique so num_valid_experts <= batch_size <= BLOCK_SIZE_M and
    # the grid can skip invalid blocks.
    EM = sorted_token_ids.size(0)
    if num_input_tokens < config["BLOCK_SIZE_M"]:
        EM = min(EM, num_input_tokens * top_k * config["BLOCK_SIZE_M"])

    def grid(meta):
        return (
            triton.cdiv(EM, meta["BLOCK_SIZE_M"])
            * triton.cdiv(B.size(1), meta["BLOCK_SIZE_N"]),
        )

    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # NOTE(review): num_experts is passed B.size(1) (the N dimension),
            # not B.size(0); kept as-is to preserve existing behavior — verify
            # against get_moe_wna16_block_config's expectations.
            num_experts=B.size(1),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

1147 

1148 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the generic ``fused_moe_kernel`` Triton kernel.

    Handles the unquantized path as well as fp8/int8 activation+weight
    quantization (tensor-, channel- or block-wise, selected via
    ``per_channel_quant`` and ``block_shape``).  When ``sorted_token_ids`` is
    None the kernel runs in "naive block assignment" mode (one routed token
    per block).
    """
    assert topk_weights is not None or not mul_routed_weight
    # Routing tensors must be contiguous along their innermost dimension.
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        if block_shape is not None:
            # Scale grid must match the number of quantization blocks of B.
            assert triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)
            assert triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        # Unquantized path: no scales allowed.
        assert A_scale is None
        assert B_scale is None

    num_input_tokens = A.size(0)
    num_tokens = num_input_tokens * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if num_input_tokens < config["BLOCK_SIZE_M"]:
            # Small batch: clamp padded length so invalid blocks are skipped.
            EM = min(EM, num_input_tokens * top_k * config["BLOCK_SIZE_M"])
    else:
        # Naive block assignment: one block per routed token.
        EM = num_tokens * config["BLOCK_SIZE_M"]

    def grid(meta):
        return (
            triton.cdiv(EM, meta["BLOCK_SIZE_M"])
            * triton.cdiv(B.size(1), meta["BLOCK_SIZE_N"]),
        )

    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # K-tile must not straddle a quantization block boundary.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_shape[0], block_shape[1])

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=B_bias is not None,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )

1256 

1257 

def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only int8/int4 quantization with a positive group size
    (``block_shape[1] > 0``) is routed to the GPTQ/AWQ (wNa16) launcher;
    every other configuration goes through the generic
    ``invoke_fused_moe_triton_kernel`` launcher.

    Raises:
        AssertionError: if the routing tensors are malformed, or if
            ``B_bias`` is supplied on the wNa16 path (unsupported there).
    """
    assert topk_weights is not None or not mul_routed_weight
    # Routing tensors must be contiguous along their innermost dimension.
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: add dedicated dispatch paths for other precision combinations
    # (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16).
    if (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    ):
        # Weight-only quantization with grouping: GPTQ/AWQ kernel.
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )

1341 

1342 

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Core fused Mixture-of-Experts computation.

    Runs, per chunk of up to 16K tokens:
      1. optional activation quantization, then ``hidden_states @ w1`` routed
         per expert into ``intermediate_cache1``;
      2. the (SiLU-and-mul) activation into ``intermediate_cache2``;
      3. optional re-quantization, then ``@ w2`` into ``intermediate_cache3``;
      4. a weighted sum over the top-k expert outputs via ``moe_sum``.

    Args:
        hidden_states: (num_tokens, hidden_size) contiguous input.
        w1: (E, N, K) first expert weights; w2: (E, K, N') second weights.
        topk_weights / topk_ids: routing weights and expert ids, same shape.
        inplace: write the result back into ``hidden_states``.
        activation: only "silu" is accepted (asserted below).
        expert_map: optional local-expert remapping for expert parallelism.
        block_shape: quantization block sizes [block_k, block_n], if any.

    Returns:
        The output tensor (``hidden_states`` itself when ``inplace=True``).
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints: the input hidden size must match w1's K dimension,
    # accounting for packed OCP MX formats (fp4 packs 2 values per byte,
    # fp6 packs 4 values per 3 bytes).
    if use_int4_w4a16:
        # INT4 stored unpacked in INT8 containers (full K dim)
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
    elif ocp_mx_scheme is not None:
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        # No expert parallelism: all experts are local.
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # Tokens are processed in chunks to bound intermediate-buffer memory.
    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    # Tuned-config lookup keyed by the (chunk) token count.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # Map torch dtype to the Triton compute dtype used by the kernels.
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    # After this, the weight-only flags are cleared so the kernels run the
    # plain floating-point path.
    if use_int8_w8a16 or use_int4_w4a16:
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        # Quantize the chunk's activations (no-op when quant_dtype is None).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # With very few routed tokens relative to the expert count, skip the
        # sort/align step and assign one block per routed token instead.
        SPARSITY_FACTOR = 4
        naive_block_assignment = (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            # NOTE(review): uses the full topk_ids.numel(), not the current
            # chunk's — presumably fine because this path only triggers for
            # tiny token counts (single chunk); confirm for multi-chunk runs.
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            # sorted_token_ids=None signals naive assignment to the kernels.
            sorted_token_ids = None

        # First GEMM: hidden_states @ w1 -> intermediate_cache1.
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # Activation (SiLU-and-mul): cache1 -> cache2.
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # Re-quantize the activated intermediate for the second GEMM.
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        if expert_map is not None:
            # Non-local experts write nothing; zero so moe_sum stays correct.
            intermediate_cache3.zero_()

        # Second GEMM: cache2 @ w2 -> intermediate_cache3 (top_k=1 here since
        # tokens were already expanded by top_k in the first GEMM).
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # Reduce over the top_k dimension into the output slice.
        # NOTE(review): .view(*size()) is a no-op reshape — possibly a
        # leftover; kept as-is.
        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

1632 

1633 

def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    In-place fused MoE: writes output directly into ``hidden_states``.

    Same semantics as ``fused_experts_impl(..., inplace=True)``.
    Returns None (the result is stored in ``hidden_states``).
    """
    # Collect the quantization-related pass-through arguments once.
    quant_kwargs = dict(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        **quant_kwargs,
    )

1685 

1686 

def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Out-of-place fused MoE: allocates and returns a new output tensor.

    Same semantics as ``fused_experts_impl(..., inplace=False)``.
    """
    # Collect the quantization-related pass-through arguments once.
    quant_kwargs = dict(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        **quant_kwargs,
    )