Coverage for src/flag_gems/fused/fused_moe.py: 43%
665 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source files under vllm/model_executor/layers/:
6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl
7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation
8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input
9# fused_moe/config.py – _get_config_dtype_str
10# quantization/utils/mxfp4_utils.py – dequant_mxfp4
11# quantization/utils/mxfp6_utils.py – dequant_mxfp6
12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE
15import functools
16import json
17import logging
18import os
19from enum import Enum
20from typing import Any, Optional
22import torch
23import torch.nn.functional as F
24import triton
25import triton.language as tl
27from flag_gems.fused.moe_align_block_size import moe_align_block_size
28from flag_gems.fused.moe_sum import moe_sum
29from flag_gems.utils import pointwise_dynamic
31logger = logging.getLogger(__name__)
33# OCP MX quantization helpers (requires amd-quark)
35OCP_MX_BLOCK_SIZE = 32
38@functools.lru_cache(maxsize=1)
39def get_embedded_moe_configs():
40 config_path = os.path.join(
41 os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.json"
42 )
43 if not os.path.exists(config_path):
44 return {}, {}
45 with open(config_path, "r") as f:
46 # JSON keys are strings, values are dicts where keys are M and values are configs
47 data = json.load(f)
49 fallback = data.get("_FALLBACK", {})
51 # We need to convert the innermost keys (which are stringified integers for M) back to integers.
52 # Ensure we map the lists back to config dicts.
53 keys_order = [
54 "BLOCK_SIZE_M",
55 "BLOCK_SIZE_N",
56 "BLOCK_SIZE_K",
57 "GROUP_SIZE_M",
58 "num_warps",
59 "num_stages",
60 ]
61 parsed_data = {}
62 for dev, configs in data.items():
63 if dev == "_FALLBACK":
64 continue
65 parsed_data[dev] = {}
66 for k, m_dict in configs.items():
67 parsed_dict = {}
68 for m, v in m_dict.items():
69 if isinstance(v, list):
70 parsed_dict[int(m)] = dict(zip(keys_order, v))
71 else:
72 parsed_dict[int(m)] = v
73 parsed_data[dev][k] = parsed_dict
75 return parsed_data, fallback
78def dequant_mxfp4(
79 x: torch.Tensor,
80 scale: torch.Tensor,
81 float_dtype: torch.dtype,
82) -> torch.Tensor:
83 """Dequantize MXFP4 tensor via quark.torch.kernel.mx.dq_mxfp4."""
84 try:
85 from quark.torch.kernel import mx
86 except ImportError as err:
87 raise ImportError("amd-quark is required for MX-FP4") from err
89 return mx.dq_mxfp4(x, scale, float_dtype)
92def dequant_mxfp6(
93 x: torch.Tensor,
94 scale: torch.Tensor,
95 float_dtype: torch.dtype,
96 quant_dtype: str,
97) -> torch.Tensor:
98 """Dequantize MXFP6 tensor via quark hw_emulation."""
99 try:
100 from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
101 dequantize_fp4_fp6_per_group,
102 )
103 from quark.torch.utils.pack import create_pack_method
104 except ImportError as err:
105 raise ImportError("amd-quark is required for MX-FP6") from err
107 pack_method = create_pack_method(None, dtype=quant_dtype)
108 unpacked_x = pack_method.unpack(x, reorder=False)
110 scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
112 return dequantize_fp4_fp6_per_group(
113 unpacked_x,
114 scale,
115 axis=-1,
116 group_size=OCP_MX_BLOCK_SIZE,
117 quant_dtype=quant_dtype,
118 ).to(float_dtype)
# Kernel config lookup helpers
124@functools.lru_cache(maxsize=1)
125def _get_device_name() -> str:
126 """Return the normalised CUDA device name (spaces replaced by underscores).
128 Matches the naming convention used by vLLM for its per-device config files.
129 H800 falls back to H100_80GB_HBM3 (same SM 9.0 architecture).
130 """
131 name = torch.cuda.get_device_name().replace(" ", "_")
132 # Normalise the H200 product family to a single key, following vLLM.
133 if "H200" in name.split("_"):
134 name = "NVIDIA_H200"
135 # H800 has the same SM 9.0 as H100; use H100 configs as fallback.
136 embedded_configs, fallback_mapping = get_embedded_moe_configs()
137 if name in embedded_configs:
138 return name
139 # Fallback mapping for devices whose tuning profiles are equivalent.
140 fallback = fallback_mapping.get(name)
141 if fallback and fallback in embedded_configs:
142 logger.info("Device %s not in config table, falling back to %s", name, fallback)
143 return fallback
144 return name
147def get_moe_configs(
148 E: int,
149 N: int,
150 dtype: str | None,
151 block_n: int | None = None,
152 block_k: int | None = None,
153) -> dict[int, Any] | None:
154 """
155 Return optimized configurations for the fused MoE kernel.
157 Looks up pre-tuned configs from the embedded table (ported from vLLM)
158 for the current GPU device. Returns None if no matching config is found.
159 """
160 device_name = _get_device_name()
161 embedded_configs, _ = get_embedded_moe_configs()
162 device_table = embedded_configs.get(device_name)
163 if device_table is None:
164 logger.warning(
165 "No embedded MoE configs for device %s. Will use default config.",
166 device_name,
167 )
168 return None
170 _block_n = block_n if block_n else 0
171 _block_k = block_k if block_k else 0
172 key = f"{E},{N},{dtype},{_block_n},{_block_k}"
173 configs = device_table.get(key)
174 if configs is not None:
175 logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
176 return configs
177 logger.warning(
178 "No embedded MoE config for device=%s, key=%s. Will use default config.",
179 device_name,
180 key,
181 )
182 return None
185def try_get_optimal_moe_config(
186 w1_shape: tuple[int, ...],
187 w2_shape: tuple[int, ...],
188 top_k: int,
189 dtype: str | None,
190 M: int,
191 block_shape: list[int] | None = None,
192) -> dict[str, int]:
193 override_config: Optional[dict[str, Any]] = None
194 if override_config:
195 config = override_config
196 else:
197 # First try to load optimal config from the file
198 E, _, N = w2_shape
199 if dtype == "int4_w4a16":
200 N = N * 2
201 block_n = block_shape[0] if block_shape else 0
202 block_k = block_shape[1] if block_shape else 0
203 configs = get_moe_configs(E, N, dtype, block_n, block_k)
205 if configs:
206 config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
207 else:
208 config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
209 return config
212def _get_config_quant_dtype(
213 use_fp8_w8a8: bool,
214 use_int8_w8a8: bool,
215 ocp_mx_scheme: str | None,
216) -> None | torch.dtype | str:
217 """Map quantization flags to the corresponding dtype."""
218 if use_fp8_w8a8:
219 return torch.float8_e4m3fn
220 elif use_int8_w8a8:
221 return torch.int8
222 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
223 return "mxfp4"
224 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
225 return "mxfp6_e3m2"
226 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
227 return "mxfp6_e2m3"
228 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
229 return torch.bfloat16
230 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
231 return torch.float8_e4m3fn
233 return None
236def get_moe_wna16_block_config(
237 config: dict[str, int],
238 use_moe_wna16_cuda: bool,
239 num_valid_tokens: int,
240 size_k: int,
241 size_n: int,
242 num_experts: int,
243 group_size: int,
244 real_top_k: int,
245 block_size_m: int,
246):
247 if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
248 return {}
249 if not use_moe_wna16_cuda:
250 if num_valid_tokens // real_top_k == 1:
251 return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
252 else:
253 return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
254 else:
255 block_size_n = 128
256 block_size_k = 128
257 if block_size_k <= group_size:
258 block_size_k = group_size
260 num_n_blocks = size_k // block_size_k
261 num_k_blocks = size_n // block_size_k
262 num_m_blocks = (
263 num_valid_tokens + block_size_m - 1
264 ) / block_size_m + num_experts
265 if num_valid_tokens // real_top_k <= block_size_m:
266 num_m_blocks = min(num_m_blocks, num_valid_tokens)
267 num_blocks = num_m_blocks * num_n_blocks * num_k_blocks
269 if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
270 block_size_k = 256
271 num_blocks = num_blocks // (256 // block_size_k)
273 if (
274 num_m_blocks <= 16
275 and size_k % (block_size_k * 2) == 0
276 and size_k % (block_size_k * 2) == 0
277 and block_size_k <= 512
278 and num_blocks >= 512
279 ):
280 block_size_k = block_size_k * 2
281 num_blocks = num_blocks // 2
283 if num_blocks > 1024:
284 block_size_n = 256
285 num_n_blocks = num_n_blocks // 2
286 num_blocks = num_blocks // 2
288 if size_n <= 1024 and num_blocks >= 1024:
289 block_size_n = 1024
291 block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
293 return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
296def get_default_config(
297 M: int,
298 E: int,
299 N: int,
300 K: int,
301 topk: int,
302 dtype: str | None,
303 block_shape: list[int] | None = None,
304) -> dict[str, int]:
305 """Default Triton config for fused MoE kernel.
307 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100.
308 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each
309 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical.
310 """
311 if dtype == "fp8_w8a8" and block_shape is not None:
312 config = {
313 "BLOCK_SIZE_M": 16 if M <= 64 else 64,
314 "BLOCK_SIZE_N": block_shape[0],
315 "BLOCK_SIZE_K": block_shape[1],
316 "GROUP_SIZE_M": 1 if M <= 16 else 32,
317 "num_warps": 4,
318 "num_stages": 3,
319 }
320 else:
321 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to
322 # estimate the actual per-expert token count after routing.
323 tokens_per_expert = M // max(E, 1)
325 if tokens_per_expert <= 2:
326 block_m = 16
327 elif tokens_per_expert <= 4:
328 block_m = 32
329 elif tokens_per_expert <= 16:
330 block_m = 64
331 else:
332 block_m = 128
334 # Tile sizing
335 if N >= 4096:
336 block_n = 128 if M <= 128 else 256
337 elif N >= 1024:
338 block_n = 64 if M <= 64 else 128
339 else:
340 block_n = 64 if M <= 64 else 128
342 if dtype == "fp8_w8a8":
343 block_k = 128
344 elif M <= 64:
345 block_k = 128
346 else:
347 block_k = 64
349 if tokens_per_expert > 128:
350 group_m = 16
351 elif tokens_per_expert > 32:
352 group_m = 8
353 else:
354 group_m = 1
356 # Prefer 4 warps for small tiles; only use 8 for large M
357 num_warps = 4 if M <= 128 else 8
358 num_stages = 3
360 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
361 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
362 num_stages -= 1
364 config = {
365 "BLOCK_SIZE_M": block_m,
366 "BLOCK_SIZE_N": block_n,
367 "BLOCK_SIZE_K": block_k,
368 "GROUP_SIZE_M": group_m,
369 "num_warps": num_warps,
370 "num_stages": num_stages,
371 }
372 return config
375def _get_config_dtype_str(
376 dtype: Optional[torch.dtype] = None,
377 use_fp8_w8a8: bool = False,
378 use_fp8_w8a16: bool = False,
379 use_int8_w8a16: bool = False,
380 use_int4_w4a16: bool = False,
381 ocp_mx_scheme: str | None = None,
382) -> str | None:
383 """Return dtype string for kernel config lookup."""
384 if use_fp8_w8a8:
385 return "fp8_w8a8"
386 elif use_fp8_w8a16:
387 return "fp8_w8a16"
388 elif use_int8_w8a16:
389 return "int8_w8a16"
390 elif use_int4_w4a16:
391 return "int4_w4a16"
392 elif ocp_mx_scheme is not None:
393 return None
394 elif dtype == torch.float:
395 return "float32"
396 return None
399# MoE activation enum
402class MoEActivation(Enum):
403 """Activation functions for MoE layers."""
405 # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
406 SILU = "silu"
407 GELU = "gelu"
408 RELU2 = "relu2"
409 SWIGLUOAI = "swigluoai"
410 SWIGLUSTEP = "swiglustep"
412 # Non-gated: input [..., d] -> output [..., d]
413 SILU_NO_MUL = "silu_no_mul"
414 GELU_NO_MUL = "gelu_no_mul"
415 RELU2_NO_MUL = "relu2_no_mul"
417 @property
418 def is_gated(self) -> bool:
419 return not self.value.endswith("_no_mul")
421 def without_mul(self) -> "MoEActivation":
422 """Return the non-gated variant."""
423 _without_mul: dict[MoEActivation, MoEActivation] = {
424 MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
425 MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
426 MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
427 }
428 return _without_mul.get(self, self)
430 @classmethod
431 def from_str(cls, s: str) -> "MoEActivation":
432 for member in cls:
433 if member.value == s:
434 return member
435 valid = [m.value for m in cls]
436 raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")
438 @staticmethod
439 def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
440 """Return N for non-gated, N // 2 for gated activations."""
441 return N if not activation.is_gated else N // 2
444def apply_moe_activation(
445 activation: MoEActivation,
446 output: torch.Tensor,
447 input: torch.Tensor,
448) -> torch.Tensor:
449 """Apply MoE activation (pure PyTorch / FlagGems Triton)."""
450 assert input.dim() == 2, "Input must be 2D"
451 assert output.dim() == 2, "Output must be 2D"
452 if activation.is_gated:
453 assert output.size(-1) * 2 == input.size(-1), (
454 f"{activation.value} expects 2x ratio: "
455 f"{output.size(-1) * 2} vs {input.size(-1)}"
456 )
457 else:
458 assert output.size(-1) == input.size(-1), (
459 f"{activation.value} expects equal sizes: "
460 f"{output.size(-1)} vs {input.size(-1)}"
461 )
463 if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
464 N = output.size(-1)
465 x, y = input[:, :N], input[:, N:]
466 _silu_and_mul_kernel(x, y, out0=output)
467 elif activation == MoEActivation.GELU:
468 N = output.size(-1)
469 gate, up = input[:, :N], input[:, N:]
470 output.copy_(F.gelu(gate) * up)
471 elif activation == MoEActivation.SWIGLUSTEP:
472 N = output.size(-1)
473 gate, up = input[:, :N], input[:, N:]
474 output.copy_(torch.sigmoid(gate) * up)
475 elif activation == MoEActivation.RELU2:
476 N = output.size(-1)
477 gate, up = input[:, :N], input[:, N:]
478 output.copy_(F.relu(gate).square() * up)
480 elif activation == MoEActivation.SILU_NO_MUL:
481 output.copy_(F.silu(input))
482 elif activation == MoEActivation.GELU_NO_MUL:
483 output.copy_(F.gelu(input))
484 elif activation == MoEActivation.RELU2_NO_MUL:
485 F.relu(input, inplace=True)
486 torch.square(input, out=output)
487 else:
488 raise ValueError(f"Unsupported FusedMoe activation: {activation}")
490 return output
493def _fp8_quantize(
494 A: torch.Tensor,
495 A_scale: Optional[torch.Tensor],
496 per_act_token: bool,
497 block_shape: Optional[list[int]] = None,
498) -> tuple[torch.Tensor, torch.Tensor]:
499 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise."""
500 fp8_dtype = torch.float8_e4m3fn
501 finfo = torch.finfo(fp8_dtype)
502 fp8_max = finfo.max
503 fp8_min = finfo.min
504 eps = 1e-10
506 if block_shape is not None:
507 assert not per_act_token
508 assert len(block_shape) == 2
509 block_k = block_shape[1]
510 assert A.size(-1) % block_k == 0
511 orig_shape = A.shape
512 A_flat = A.reshape(-1, A.size(-1))
513 M, K = A_flat.shape
514 A_groups = A_flat.reshape(M * (K // block_k), block_k)
515 amax = (
516 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
517 )
518 scale = amax / fp8_max
519 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
520 A_q = A_q.reshape(orig_shape)
521 scale = scale.reshape(M, K // block_k)
522 return A_q, scale
524 elif per_act_token:
525 A_flat = A.reshape(-1, A.size(-1))
526 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
527 scale = amax / fp8_max
528 min_scale = torch.tensor(
529 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device
530 )
531 scale = scale.clamp(min=min_scale)
532 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
533 A_q = A_q.reshape(A.shape)
534 scale = scale.reshape(A.shape[:-1] + (1,))
535 return A_q, scale
537 else:
538 if A_scale is not None:
539 scale = (
540 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
541 )
542 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
543 return A_q, A_scale
544 else:
545 amax = A.abs().amax().clamp(min=eps).to(torch.float32)
546 scale = amax / fp8_max
547 iscale = 1.0 / scale
548 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype)
549 return A_q, scale.view(1)
552def _int8_quantize(
553 A: torch.Tensor,
554 A_scale: Optional[torch.Tensor],
555 per_act_token: bool,
556 block_shape: Optional[list[int]] = None,
557) -> tuple[torch.Tensor, torch.Tensor]:
558 """INT8 quantization: per-tensor, per-token, or block-wise."""
559 iinfo = torch.iinfo(torch.int8)
560 int8_max = iinfo.max
561 int8_min = iinfo.min
562 eps = 1e-10
564 if block_shape is not None:
565 assert not per_act_token
566 assert len(block_shape) == 2
567 block_k = block_shape[1]
568 assert A.size(-1) % block_k == 0
569 orig_shape = A.shape
570 A_flat = A.reshape(-1, A.size(-1))
571 M, K = A_flat.shape
572 A_groups = A_flat.reshape(M * (K // block_k), block_k)
573 amax = (
574 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
575 )
576 scale = amax / int8_max
577 A_q = (
578 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
579 )
580 A_q = A_q.reshape(orig_shape)
581 scale = scale.reshape(M, K // block_k)
582 return A_q, scale
584 elif per_act_token:
585 A_flat = A.reshape(-1, A.size(-1))
586 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
587 scale = amax / int8_max
588 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
589 A_q = A_q.reshape(A.shape)
590 scale = scale.reshape(A.shape[:-1] + (1,))
591 return A_q, scale
593 else:
594 assert A_scale is not None, "int8 per-tensor requires A_scale"
595 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
596 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
597 return A_q, A_scale
600def moe_kernel_quantize_input(
601 A: torch.Tensor,
602 A_scale: Optional[torch.Tensor],
603 quant_dtype: None | torch.dtype | str,
604 per_act_token_quant: bool,
605 block_shape: Optional[list[int]] = None,
606 ocp_mx_scheme: str | None = None,
607) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
608 """Quantize MoE input activations before GEMM."""
609 if ocp_mx_scheme is not None:
610 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
611 pass
612 elif ocp_mx_scheme.endswith("a_fp8"):
613 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False)
614 A = (qA.float() * qA_scale.float()).to(A.dtype)
615 return A, None
617 if quant_dtype is None:
618 return A, A_scale
619 elif quant_dtype == torch.float8_e4m3fn:
620 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
621 elif quant_dtype == torch.int8:
622 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
623 else:
624 return A, A_scale
627def _ensure_block_size_k_divisible(
628 size_k: int, block_size_k: int, group_size: int
629) -> int:
630 """Find largest block_size_k that divides size_k and is divisible by group_size."""
631 if size_k % block_size_k == 0 and block_size_k % group_size == 0:
632 return block_size_k
634 max_search = min(block_size_k, size_k)
635 start = (max_search // group_size) * group_size
636 for candidate in range(start, group_size - 1, -group_size):
637 if size_k % candidate == 0:
638 return candidate
640 if size_k % group_size == 0:
641 return group_size
643 return size_k
646@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
647@triton.jit
648def _silu_and_mul_kernel(x, y):
649 x_fp32 = x.to(tl.float32)
650 x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
651 return x_silu * y
654@triton.jit
655def write_zeros_to_output(
656 c_ptr,
657 stride_cm,
658 stride_cn,
659 pid_n,
660 N,
661 offs_token,
662 token_mask,
663 BLOCK_SIZE_M,
664 BLOCK_SIZE_N,
665 compute_type,
666):
667 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
668 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
669 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
670 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
671 tl.store(c_ptrs, accumulator, mask=c_mask)
674@triton.jit
675def fused_moe_kernel_gptq_awq(
676 # Pointers to matrices
677 a_ptr,
678 b_ptr,
679 c_ptr,
680 b_scale_ptr,
681 b_zp_ptr,
682 topk_weights_ptr,
683 sorted_token_ids_ptr,
684 expert_ids_ptr,
685 num_tokens_post_padded_ptr,
686 # Matrix dimensions
687 N: tl.constexpr,
688 K: tl.constexpr,
689 EM,
690 num_valid_tokens,
691 # The stride variables represent how much to increase the ptr by when
692 # moving by 1 element in a particular dimension. E.g. `stride_am` is
693 # how much to increase `a_ptr` by to get the element one row down
694 # (A has M rows).
695 stride_am,
696 stride_ak,
697 stride_be,
698 stride_bk,
699 stride_bn,
700 stride_cm,
701 stride_cn,
702 stride_bse,
703 stride_bsk,
704 stride_bsn,
705 stride_bze,
706 stride_bzk,
707 stride_bzn,
708 block_k_diviable: tl.constexpr,
709 group_size: tl.constexpr,
710 # Meta-parameters
711 BLOCK_SIZE_M: tl.constexpr,
712 BLOCK_SIZE_N: tl.constexpr,
713 BLOCK_SIZE_K: tl.constexpr,
714 GROUP_SIZE_M: tl.constexpr,
715 SPLIT_K: tl.constexpr,
716 MUL_ROUTED_WEIGHT: tl.constexpr,
717 top_k: tl.constexpr,
718 compute_type: tl.constexpr,
719 has_zp: tl.constexpr,
720 use_int4_w4a16: tl.constexpr,
721 use_int8_w8a16: tl.constexpr,
722):
723 """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights."""
724 # Map pid to C block (grouped ordering for L2 reuse)
725 pid = tl.program_id(axis=0)
726 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
727 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
728 num_pid_in_group = GROUP_SIZE_M * num_pid_n
729 group_id = pid // num_pid_in_group
730 first_pid_m = group_id * GROUP_SIZE_M
731 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
732 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
733 pid_n = (pid % num_pid_in_group) // group_size_m
735 # Create pointers for first blocks of A and B
736 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
737 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
738 return
739 offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
740 # Cast to int64 to prevent overflow in stride*offset products
741 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
742 token_mask = offs_token < num_valid_tokens
744 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
745 if off_experts == -1:
746 # -----------------------------------------------------------
747 # Write back zeros to the output when the expert is not
748 # in the current expert parallel rank.
749 write_zeros_to_output(
750 c_ptr,
751 stride_cm,
752 stride_cn,
753 pid_n,
754 N,
755 offs_token,
756 token_mask,
757 BLOCK_SIZE_M,
758 BLOCK_SIZE_N,
759 compute_type,
760 )
761 return
763 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
764 offs_k = tl.arange(0, BLOCK_SIZE_K)
765 a_ptrs = a_ptr + (
766 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
767 )
769 if use_int4_w4a16:
770 b_ptrs = (
771 b_ptr
772 + off_experts * stride_be
773 + (offs_k[:, None] // 2) * stride_bk
774 + offs_bn[None, :] * stride_bn
775 )
776 b_shifter = (offs_k[:, None] % 2) * 4
777 elif use_int8_w8a16:
778 b_ptrs = (
779 b_ptr
780 + off_experts * stride_be
781 + offs_k[:, None] * stride_bk
782 + offs_bn[None, :] * stride_bn
783 )
785 if not has_zp and use_int4_w4a16:
786 b_zp_num = 8
787 if not has_zp and use_int8_w8a16:
788 b_zp_num = 128
789 elif has_zp and use_int4_w4a16:
790 b_zp_shifter = (offs_bn[None, :] % 2) * 4
792 # Accumulate C block in fp32
793 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
794 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
795 if not block_k_diviable:
796 k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
797 k_other = 0.0
798 else:
799 k_mask = None
800 k_other = None
802 a = tl.load(
803 a_ptrs,
804 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
805 other=0.0,
806 )
807 b = tl.load(b_ptrs)
808 if use_int4_w4a16:
809 b = (b >> b_shifter) & 0xF
811 b_scale_ptrs = (
812 b_scale_ptr
813 + off_experts * stride_bse
814 + offs_bn[None, :] * stride_bsn
815 + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
816 )
817 b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
818 b_scale = b_scale.to(tl.float32)
820 if has_zp and use_int4_w4a16:
821 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
822 b_zp_ptrs = (
823 b_zp_ptr
824 + off_experts * stride_bze
825 + (offs_bn[None, :] // 2) * stride_bzn
826 + offs_k_true * stride_bzk
827 )
828 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
829 b_zp = (b_zp >> b_zp_shifter) & 0xF
830 b_zp = b_zp.to(tl.float32)
831 elif has_zp and use_int8_w8a16:
832 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
833 b_zp_ptrs = (
834 b_zp_ptr
835 + off_experts * stride_bze
836 + offs_bn[None, :] * stride_bzn
837 + offs_k_true * stride_bzk
838 )
839 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
840 b_zp = b_zp.to(tl.float32)
842 if has_zp:
843 b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
844 else:
845 b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
846 accumulator = tl.dot(a, b, acc=accumulator)
848 a_ptrs += BLOCK_SIZE_K * stride_ak
849 if use_int4_w4a16:
850 b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
851 else:
852 b_ptrs += BLOCK_SIZE_K * stride_bk
854 if MUL_ROUTED_WEIGHT:
855 moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
856 accumulator = accumulator * moe_weight[:, None]
858 accumulator = accumulator.to(compute_type)
859 # Write back output
860 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
861 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
862 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
863 tl.store(c_ptrs, accumulator, mask=c_mask)
866@triton.jit
867def fused_moe_kernel(
868 # Pointers to matrices
869 a_ptr,
870 b_ptr,
871 c_ptr,
872 b_bias_ptr,
873 a_scale_ptr,
874 b_scale_ptr,
875 topk_weights_ptr,
876 sorted_token_ids_ptr,
877 expert_ids_ptr,
878 num_tokens_post_padded_ptr,
879 # Matrix dimensions
880 N,
881 K,
882 EM,
883 num_valid_tokens,
884 stride_am,
885 stride_ak,
886 stride_be,
887 stride_bk,
888 stride_bn,
889 stride_cm,
890 stride_cn,
891 stride_asm,
892 stride_ask,
893 stride_bse,
894 stride_bsk,
895 stride_bsn,
896 stride_bbe, # bias expert stride
897 stride_bbn, # bias N stride
898 # Block size for block-wise quantization
899 group_n: tl.constexpr,
900 group_k: tl.constexpr,
901 naive_block_assignment: tl.constexpr,
902 # Meta-parameters
903 BLOCK_SIZE_M: tl.constexpr,
904 BLOCK_SIZE_N: tl.constexpr,
905 BLOCK_SIZE_K: tl.constexpr,
906 GROUP_SIZE_M: tl.constexpr,
907 SPLIT_K: tl.constexpr,
908 MUL_ROUTED_WEIGHT: tl.constexpr,
909 top_k: tl.constexpr,
910 compute_type: tl.constexpr,
911 use_fp8_w8a8: tl.constexpr,
912 use_int8_w8a8: tl.constexpr,
913 use_int8_w8a16: tl.constexpr,
914 per_channel_quant: tl.constexpr,
915 HAS_BIAS: tl.constexpr,
916):
917 """Fused MoE kernel: token × expert GEMM with quantization support."""
918 # Map pid to C block (grouped ordering for L2 reuse)
919 pid = tl.program_id(axis=0)
920 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
921 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
922 num_pid_in_group = GROUP_SIZE_M * num_pid_n
923 group_id = pid // num_pid_in_group
924 first_pid_m = group_id * GROUP_SIZE_M
925 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
926 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
927 pid_n = (pid % num_pid_in_group) // group_size_m
929 # Create pointers for first blocks of A and B
930 offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
931 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
932 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
933 return
934 if not naive_block_assignment:
935 offs_token_id = pid_m * BLOCK_SIZE_M + offs
936 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
937 else:
938 offs_token = tl.where(
939 offs == 0,
940 pid_m, # first element = pid_m
941 num_valid_tokens, # remaining elements = constant
942 )
943 offs_token = offs_token.to(tl.int64) # prevent int32 overflow
945 token_mask = offs_token < num_valid_tokens
947 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
948 if off_experts == -1:
949 # Expert not in current EP rank, write zeros
950 write_zeros_to_output(
951 c_ptr,
952 stride_cm,
953 stride_cn,
954 pid_n,
955 N,
956 offs_token,
957 token_mask,
958 BLOCK_SIZE_M,
959 BLOCK_SIZE_N,
960 compute_type,
961 )
962 return
964 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
965 offs_k = tl.arange(0, BLOCK_SIZE_K)
966 a_ptrs = a_ptr + (
967 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
968 )
970 b_ptrs = (
971 b_ptr
972 + off_experts * stride_be
973 + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
974 )
975 if use_int8_w8a16:
976 b_scale_ptrs = (
977 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
978 )
979 b_scale = tl.load(b_scale_ptrs)
981 if use_fp8_w8a8 or use_int8_w8a8:
982 if group_k > 0 and group_n > 0: # block-wise
983 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
984 offs_bsn = offs_bn // group_n
985 b_scale_ptrs = (
986 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
987 )
988 elif per_channel_quant: # channel-wise
989 b_scale_ptrs = (
990 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
991 )
992 b_scale = tl.load(b_scale_ptrs)
993 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
994 a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
995 else: # tensor-wise
996 a_scale = tl.load(a_scale_ptr)
997 b_scale = tl.load(b_scale_ptr + off_experts)
998 if HAS_BIAS:
999 bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
1000 bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
1001 # Accumulate C block in fp32
1002 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1003 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
1004 a = tl.load(
1005 a_ptrs,
1006 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
1007 other=0.0,
1008 )
1009 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
1010 if use_int8_w8a16:
1011 accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
1012 elif use_fp8_w8a8 or use_int8_w8a8:
1013 if group_k > 0 and group_n > 0:
1014 k_start = k * BLOCK_SIZE_K
1015 offs_ks = k_start // group_k
1016 a_scale = tl.load(
1017 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
1018 )
1019 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
1021 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
1022 else:
1023 if use_fp8_w8a8:
1024 accumulator = tl.dot(a, b, acc=accumulator)
1025 else:
1026 accumulator += tl.dot(a, b)
1027 else:
1028 accumulator += tl.dot(a, b)
1029 a_ptrs += BLOCK_SIZE_K * stride_ak
1030 b_ptrs += BLOCK_SIZE_K * stride_bk
1032 # Dequantization
1033 if use_int8_w8a16:
1034 accumulator = accumulator * b_scale
1035 elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
1036 accumulator = accumulator * a_scale * b_scale
1038 if HAS_BIAS:
1039 accumulator += bias[None, :]
1041 # Router weight multiplication (must be in fp32)
1042 if MUL_ROUTED_WEIGHT:
1043 moe_weight = tl.load(
1044 topk_weights_ptr + offs_token,
1045 mask=token_mask,
1046 other=0,
1047 )
1048 accumulator *= moe_weight[:, None]
1050 accumulator = accumulator.to(compute_type)
1052 # Write back output
1053 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1054 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
1055 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
1056 tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch ``fused_moe_kernel_gptq_awq`` for weight-only (wNa16) MoE.

    Handles GPTQ/AWQ-style group-quantized weights (INT4/INT8 weights with
    16-bit activations). ``B`` is indexed per expert along dim 0 (its dim-0
    stride is passed as the expert stride below); dim 1 is N and dim 2 the
    (possibly packed) K dimension — assumed (E, N, K-like), TODO confirm
    against the kernel definition.

    Args:
        A: (M, K) activations.
        B: per-expert quantized weights.
        C: output buffer written by the kernel.
        B_scale: group scales, required, shape (E, ..., ...) with ndim == 3.
        B_zp: optional zero points, ndim == 3 when present.
        topk_weights: router weights, used when ``mul_routed_weight``.
        sorted_token_ids / expert_ids / num_tokens_post_padded: block
            assignment produced by ``moe_align_block_size``.
        mul_routed_weight: multiply output rows by router weights.
        top_k: number of experts per token.
        config: Triton tile configuration (BLOCK_SIZE_M/N/K, ...).
        compute_type: Triton dtype used for accumul./output conversion.
        use_int8_w8a16 / use_int4_w4a16: weight precision selector.
        block_shape: [0, group_size] — dim 0 must be 0 for the wNa16 path.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    # Weight-only quantization is grouped along K only (block_shape[0] == 0).
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    # 1-D launch grid: one program per (M-block, N-block) pair.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # BUGFIX: the expert count is B's dim 0, not dim 1 (dim 1 is N
            # and is already passed as size_n above).
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Validates the quantization-scale layout for the selected precision mode,
    computes the launch grid, and forwards tensors/strides positionally to
    ``fused_moe_kernel``. ``A`` is (M, K) activations; ``B`` is per-expert
    weights whose dim 1 is N and dim 2 is K (its size(1)/size(2) are passed
    as N/K below). When ``sorted_token_ids`` is None the kernel runs in
    "naive block assignment" mode (see the constexpr flag at the bottom).

    NOTE: argument order in the kernel launch is positional and must match
    the kernel signature exactly — do not reorder.
    """
    # Router weights are required whenever the kernel must apply them.
    assert topk_weights is not None or not mul_routed_weight
    # The kernel indexes these along their last/first dim with unit stride.
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        # Activation+weight quantization: B_scale must tile B per block_shape.
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        # Weight-only quantization: groups run along K only.
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        # Unquantized path: no scales expected.
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if A.size(0) < config["BLOCK_SIZE_M"]:
            # Small-batch shortcut: assuming each token's top-k experts are
            # unique, only this many padded slots can be valid.
            EM = min(
                sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
            )
    else:
        # Naive block assignment: one BLOCK_SIZE_M-sized block per routed token.
        EM = num_tokens * config["BLOCK_SIZE_M"]
    # 1-D grid: one program per (M-block, N-block) pair.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    # Copy before mutating: `config` may be cached/shared by the caller.
    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # K tile must not straddle a quantization group boundary.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        # Scale strides are 0 when the scale tensor is absent or not 2-D/3-D;
        # the kernel then treats the scale as tensor-wise.
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        # group_k / group_n for block-quantized scales (0 disables grouping).
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )
def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Routes weight-only group-quantized workloads (INT8/INT4 weights with a
    K-group size, i.e. ``block_shape[1] > 0``) to the GPTQ/AWQ wNa16 kernel;
    everything else goes to the generic ``fused_moe_kernel`` launcher.

    Args mirror the two ``invoke_*`` launchers; see their docstrings for
    tensor layout details.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: add dedicated dispatch branches for other precision-specific
    # implementations (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
    # use_int4_w4a16) when specialized kernels become available.

    if (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    ):
        # Weight-only path: the wNa16 kernel does not support a fused bias.
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Core fused Mixture-of-Experts implementation.

    Pipeline per chunk of tokens: quantize activations -> GEMM with w1 ->
    activation (silu-and-mul style, via ``apply_moe_activation``) ->
    re-quantize -> GEMM with w2 -> weighted sum over the token's top-k
    experts into the output.

    Shapes (from the unpacking below): ``w1`` is (E, N, K-like) and ``w2``
    is (E, K, N-like); ``hidden_states`` is (num_tokens, hidden_size);
    ``topk_weights``/``topk_ids`` are (num_tokens, top_k).

    Returns ``hidden_states`` itself when ``inplace`` is True, otherwise a
    freshly allocated tensor of the same shape.
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints
    if use_int4_w4a16:
        # INT4 stored unpacked in INT8 containers (full K dim)
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
    elif ocp_mx_scheme is not None:
        # OCP MX weights are packed: fp4 packs 2 values/byte, fp6 packs
        # 4 values into 3 bytes — hence the adjusted K checks below.
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()  # E = local experts, N = w1 output dim
    K = w2.size(1)  # final output (hidden) dim
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # Process tokens in chunks to bound intermediate-buffer memory.
    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    # Partially-applied config lookup: only the chunk size (M) varies below.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # Map torch dtype to the Triton compute dtype used by the kernels.
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    # After this, the flags are cleared so the kernels run the plain path.
    if use_int8_w8a16 or use_int4_w4a16:
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        # num_tokens may be an exact multiple of CHUNK_SIZE: final slice empty.
        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        # Quantize chunk activations for the first GEMM (no-op when
        # quant_dtype is None).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # Heuristic: when routed tokens are much sparser than the expert
        # count, skip the sort in moe_align_block_size and assign one block
        # per routed token ("naive block assignment").
        SPARSITY_FACTOR = 4
        naive_block_assignment = (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            # NOTE(review): uses the full topk_ids, not curr_topk_ids; the
            # sparsity condition above should preclude multi-chunk inputs
            # here — confirm if CHUNK_SIZE or the heuristic changes.
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            sorted_token_ids = None

        # First GEMM: x @ w1 (router weight applied here only when
        # apply_router_weight_on_input is set).
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # Activation: cache1 (flattened) -> cache2.
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # Re-quantize activation output for the second GEMM.
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # With an expert map, some rows are never written — clear stale data
        # before the reduction.
        if expert_map is not None:
            intermediate_cache3.zero_()

        # Second GEMM: act_out @ w2, top_k=1 since rows are already expanded.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # Reduce over the top_k dimension into the output slice.
        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused MoE computation in place.

    Thin convenience wrapper around ``fused_experts_impl`` with
    ``inplace=True``: the result is written directly into
    ``hidden_states`` and None is returned.
    """
    forwarded = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, inplace=True, **forwarded
    )
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused MoE computation out of place.

    Thin convenience wrapper around ``fused_experts_impl`` with
    ``inplace=False``: allocates and returns a new output tensor,
    leaving ``hidden_states`` untouched.
    """
    forwarded = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    return fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, inplace=False, **forwarded
    )