Coverage for src/flag_gems/fused/fused_moe.py: 40%
614 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source files under vllm/model_executor/layers/:
6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl
7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation
8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input
9# fused_moe/config.py – _get_config_dtype_str
10# quantization/utils/mxfp4_utils.py – dequant_mxfp4
11# quantization/utils/mxfp6_utils.py – dequant_mxfp6
12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE
15import functools
16import logging
17from enum import Enum
18from typing import Any, Optional
20import torch
21import torch.nn.functional as F
22import triton
23import triton.language as tl
25from flag_gems.fused.moe_align_block_size import moe_align_block_size
26from flag_gems.fused.moe_sum import moe_sum
27from flag_gems.utils import pointwise_dynamic
29logger = logging.getLogger(__name__)
31# OCP MX quantization helpers (requires amd-quark)
33OCP_MX_BLOCK_SIZE = 32
36def dequant_mxfp4(
37 x: torch.Tensor,
38 scale: torch.Tensor,
39 float_dtype: torch.dtype,
40) -> torch.Tensor:
41 """Dequantize MXFP4 tensor via quark.torch.kernel.mx.dq_mxfp4."""
42 try:
43 from quark.torch.kernel import mx
44 except ImportError as err:
45 raise ImportError("amd-quark is required for MX-FP4") from err
47 return mx.dq_mxfp4(x, scale, float_dtype)
50def dequant_mxfp6(
51 x: torch.Tensor,
52 scale: torch.Tensor,
53 float_dtype: torch.dtype,
54 quant_dtype: str,
55) -> torch.Tensor:
56 """Dequantize MXFP6 tensor via quark hw_emulation."""
57 try:
58 from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
59 dequantize_fp4_fp6_per_group,
60 )
61 from quark.torch.utils.pack import create_pack_method
62 except ImportError as err:
63 raise ImportError("amd-quark is required for MX-FP6") from err
65 pack_method = create_pack_method(None, dtype=quant_dtype)
66 unpacked_x = pack_method.unpack(x, reorder=False)
68 scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
70 return dequantize_fp4_fp6_per_group(
71 unpacked_x,
72 scale,
73 axis=-1,
74 group_size=OCP_MX_BLOCK_SIZE,
75 quant_dtype=quant_dtype,
76 ).to(float_dtype)
79# Activation quantization helpers
82def get_moe_configs(
83 E: int,
84 N: int,
85 dtype: str | None,
86 block_n: int | None = None,
87 block_k: int | None = None,
88) -> dict[int, Any] | None:
89 """Return None; FlagGems uses get_default_config instead."""
90 return None
93def try_get_optimal_moe_config(
94 w1_shape: tuple[int, ...],
95 w2_shape: tuple[int, ...],
96 top_k: int,
97 dtype: str | None,
98 M: int,
99 block_shape: list[int] | None = None,
100) -> dict[str, int]:
101 override_config: Optional[dict[str, Any]] = None
102 if override_config:
103 config = override_config
104 else:
105 # First try to load optimal config from the file
106 E, _, N = w2_shape
107 if dtype == "int4_w4a16":
108 N = N * 2
109 block_n = block_shape[0] if block_shape else 0
110 block_k = block_shape[1] if block_shape else 0
111 configs = get_moe_configs(E, N, dtype, block_n, block_k)
113 if configs:
114 config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
115 else:
116 config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
117 return config
120def _get_config_quant_dtype(
121 use_fp8_w8a8: bool,
122 use_int8_w8a8: bool,
123 ocp_mx_scheme: str | None,
124) -> None | torch.dtype | str:
125 """Map quantization flags to the corresponding dtype."""
126 if use_fp8_w8a8:
127 return torch.float8_e4m3fn
128 elif use_int8_w8a8:
129 return torch.int8
130 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
131 return "mxfp4"
132 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
133 return "mxfp6_e3m2"
134 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
135 return "mxfp6_e2m3"
136 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
137 return torch.bfloat16
138 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
139 return torch.float8_e4m3fn
141 return None
144def get_moe_wna16_block_config(
145 config: dict[str, int],
146 use_moe_wna16_cuda: bool,
147 num_valid_tokens: int,
148 size_k: int,
149 size_n: int,
150 num_experts: int,
151 group_size: int,
152 real_top_k: int,
153 block_size_m: int,
154):
155 if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
156 return {}
157 if not use_moe_wna16_cuda:
158 if num_valid_tokens // real_top_k == 1:
159 return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
160 else:
161 return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
162 else:
163 block_size_n = 128
164 block_size_k = 128
165 if block_size_k <= group_size:
166 block_size_k = group_size
168 num_n_blocks = size_k // block_size_k
169 num_k_blocks = size_n // block_size_k
170 num_m_blocks = (
171 num_valid_tokens + block_size_m - 1
172 ) / block_size_m + num_experts
173 if num_valid_tokens // real_top_k <= block_size_m:
174 num_m_blocks = min(num_m_blocks, num_valid_tokens)
175 num_blocks = num_m_blocks * num_n_blocks * num_k_blocks
177 if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
178 block_size_k = 256
179 num_blocks = num_blocks // (256 // block_size_k)
181 if (
182 num_m_blocks <= 16
183 and size_k % (block_size_k * 2) == 0
184 and size_k % (block_size_k * 2) == 0
185 and block_size_k <= 512
186 and num_blocks >= 512
187 ):
188 block_size_k = block_size_k * 2
189 num_blocks = num_blocks // 2
191 if num_blocks > 1024:
192 block_size_n = 256
193 num_n_blocks = num_n_blocks // 2
194 num_blocks = num_blocks // 2
196 if size_n <= 1024 and num_blocks >= 1024:
197 block_size_n = 1024
199 block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
201 return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
204def get_default_config(
205 M: int,
206 E: int,
207 N: int,
208 K: int,
209 topk: int,
210 dtype: str | None,
211 block_shape: list[int] | None = None,
212) -> dict[str, int]:
213 """Default Triton config for fused MoE kernel."""
214 if dtype == "fp8_w8a8" and block_shape is not None:
215 config = {
216 "BLOCK_SIZE_M": 16 if M <= 64 else 64,
217 "BLOCK_SIZE_N": block_shape[0],
218 "BLOCK_SIZE_K": block_shape[1],
219 "GROUP_SIZE_M": 1 if M <= 16 else 32,
220 "num_warps": 4,
221 "num_stages": 3,
222 }
223 else:
224 if M <= 32:
225 block_m = 16
226 elif M <= 96:
227 block_m = 32
228 elif M <= 512:
229 block_m = 64
230 else:
231 block_m = 128
233 # Tile sizing for H100/H800
234 if N >= 4096:
235 block_n = 128 if M <= 128 else 256
236 elif N >= 1024:
237 block_n = 64 if M <= 64 else 128
238 else:
239 block_n = 64 if M <= 64 else 128
241 if dtype == "fp8_w8a8":
242 block_k = 128
243 elif K >= 4096 or M <= 64:
244 block_k = 128
245 else:
246 block_k = 64
248 tokens_per_expert = (M * topk) // max(E, 1)
249 if tokens_per_expert > 128:
250 group_m = 16
251 elif tokens_per_expert > 32:
252 group_m = 8
253 else:
254 group_m = 1
256 num_warps = 4 if block_m * block_n < 8192 else 8
257 num_stages = 3
259 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
260 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
261 num_stages -= 1
263 config = {
264 "BLOCK_SIZE_M": block_m,
265 "BLOCK_SIZE_N": block_n,
266 "BLOCK_SIZE_K": block_k,
267 "GROUP_SIZE_M": group_m,
268 "num_warps": num_warps,
269 "num_stages": num_stages,
270 }
271 return config
274def _get_config_dtype_str(
275 dtype: Optional[torch.dtype] = None,
276 use_fp8_w8a8: bool = False,
277 use_fp8_w8a16: bool = False,
278 use_int8_w8a16: bool = False,
279 use_int4_w4a16: bool = False,
280 ocp_mx_scheme: str | None = None,
281) -> str | None:
282 """Return dtype string for kernel config lookup."""
283 if use_fp8_w8a8:
284 return "fp8_w8a8"
285 elif use_fp8_w8a16:
286 return "fp8_w8a16"
287 elif use_int8_w8a16:
288 return "int8_w8a16"
289 elif use_int4_w4a16:
290 return "int4_w4a16"
291 elif ocp_mx_scheme is not None:
292 return None
293 elif dtype == torch.float:
294 return "float32"
295 return None
298# MoE activation enum
301class MoEActivation(Enum):
302 """Activation functions for MoE layers."""
304 # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
305 SILU = "silu"
306 GELU = "gelu"
307 RELU2 = "relu2"
308 SWIGLUOAI = "swigluoai"
309 SWIGLUSTEP = "swiglustep"
311 # Non-gated: input [..., d] -> output [..., d]
312 SILU_NO_MUL = "silu_no_mul"
313 GELU_NO_MUL = "gelu_no_mul"
314 RELU2_NO_MUL = "relu2_no_mul"
316 @property
317 def is_gated(self) -> bool:
318 return not self.value.endswith("_no_mul")
320 def without_mul(self) -> "MoEActivation":
321 """Return the non-gated variant."""
322 _without_mul: dict[MoEActivation, MoEActivation] = {
323 MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
324 MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
325 MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
326 }
327 return _without_mul.get(self, self)
329 @classmethod
330 def from_str(cls, s: str) -> "MoEActivation":
331 for member in cls:
332 if member.value == s:
333 return member
334 valid = [m.value for m in cls]
335 raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")
337 @staticmethod
338 def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
339 """Return N for non-gated, N // 2 for gated activations."""
340 return N if not activation.is_gated else N // 2
343def apply_moe_activation(
344 activation: MoEActivation,
345 output: torch.Tensor,
346 input: torch.Tensor,
347) -> torch.Tensor:
348 """Apply MoE activation (pure PyTorch / FlagGems Triton)."""
349 assert input.dim() == 2, "Input must be 2D"
350 assert output.dim() == 2, "Output must be 2D"
351 if activation.is_gated:
352 assert output.size(-1) * 2 == input.size(-1), (
353 f"{activation.value} expects 2x ratio: "
354 f"{output.size(-1) * 2} vs {input.size(-1)}"
355 )
356 else:
357 assert output.size(-1) == input.size(-1), (
358 f"{activation.value} expects equal sizes: "
359 f"{output.size(-1)} vs {input.size(-1)}"
360 )
362 if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
363 N = output.size(-1)
364 x, y = input[:, :N], input[:, N:]
365 _silu_and_mul_kernel(x, y, out0=output)
366 elif activation == MoEActivation.GELU:
367 N = output.size(-1)
368 gate, up = input[:, :N], input[:, N:]
369 output.copy_(F.gelu(gate) * up)
370 elif activation == MoEActivation.SWIGLUSTEP:
371 N = output.size(-1)
372 gate, up = input[:, :N], input[:, N:]
373 output.copy_(torch.sigmoid(gate) * up)
374 elif activation == MoEActivation.RELU2:
375 N = output.size(-1)
376 gate, up = input[:, :N], input[:, N:]
377 output.copy_(F.relu(gate).square() * up)
379 elif activation == MoEActivation.SILU_NO_MUL:
380 output.copy_(F.silu(input))
381 elif activation == MoEActivation.GELU_NO_MUL:
382 output.copy_(F.gelu(input))
383 elif activation == MoEActivation.RELU2_NO_MUL:
384 F.relu(input, inplace=True)
385 torch.square(input, out=output)
386 else:
387 raise ValueError(f"Unsupported FusedMoe activation: {activation}")
389 return output
392def _fp8_quantize(
393 A: torch.Tensor,
394 A_scale: Optional[torch.Tensor],
395 per_act_token: bool,
396 block_shape: Optional[list[int]] = None,
397) -> tuple[torch.Tensor, torch.Tensor]:
398 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise."""
399 fp8_dtype = torch.float8_e4m3fn
400 finfo = torch.finfo(fp8_dtype)
401 fp8_max = finfo.max
402 fp8_min = finfo.min
403 eps = 1e-10
405 if block_shape is not None:
406 assert not per_act_token
407 assert len(block_shape) == 2
408 block_k = block_shape[1]
409 assert A.size(-1) % block_k == 0
410 orig_shape = A.shape
411 A_flat = A.reshape(-1, A.size(-1))
412 M, K = A_flat.shape
413 A_groups = A_flat.reshape(M * (K // block_k), block_k)
414 amax = (
415 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
416 )
417 scale = amax / fp8_max
418 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
419 A_q = A_q.reshape(orig_shape)
420 scale = scale.reshape(M, K // block_k)
421 return A_q, scale
423 elif per_act_token:
424 A_flat = A.reshape(-1, A.size(-1))
425 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
426 scale = amax / fp8_max
427 min_scale = torch.tensor(
428 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device
429 )
430 scale = scale.clamp(min=min_scale)
431 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
432 A_q = A_q.reshape(A.shape)
433 scale = scale.reshape(A.shape[:-1] + (1,))
434 return A_q, scale
436 else:
437 if A_scale is not None:
438 scale = (
439 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
440 )
441 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
442 return A_q, A_scale
443 else:
444 amax = A.abs().amax().clamp(min=eps).to(torch.float32)
445 scale = amax / fp8_max
446 iscale = 1.0 / scale
447 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype)
448 return A_q, scale.view(1)
451def _int8_quantize(
452 A: torch.Tensor,
453 A_scale: Optional[torch.Tensor],
454 per_act_token: bool,
455 block_shape: Optional[list[int]] = None,
456) -> tuple[torch.Tensor, torch.Tensor]:
457 """INT8 quantization: per-tensor, per-token, or block-wise."""
458 iinfo = torch.iinfo(torch.int8)
459 int8_max = iinfo.max
460 int8_min = iinfo.min
461 eps = 1e-10
463 if block_shape is not None:
464 assert not per_act_token
465 assert len(block_shape) == 2
466 block_k = block_shape[1]
467 assert A.size(-1) % block_k == 0
468 orig_shape = A.shape
469 A_flat = A.reshape(-1, A.size(-1))
470 M, K = A_flat.shape
471 A_groups = A_flat.reshape(M * (K // block_k), block_k)
472 amax = (
473 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
474 )
475 scale = amax / int8_max
476 A_q = (
477 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
478 )
479 A_q = A_q.reshape(orig_shape)
480 scale = scale.reshape(M, K // block_k)
481 return A_q, scale
483 elif per_act_token:
484 A_flat = A.reshape(-1, A.size(-1))
485 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
486 scale = amax / int8_max
487 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
488 A_q = A_q.reshape(A.shape)
489 scale = scale.reshape(A.shape[:-1] + (1,))
490 return A_q, scale
492 else:
493 assert A_scale is not None, "int8 per-tensor requires A_scale"
494 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
495 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
496 return A_q, A_scale
499def moe_kernel_quantize_input(
500 A: torch.Tensor,
501 A_scale: Optional[torch.Tensor],
502 quant_dtype: None | torch.dtype | str,
503 per_act_token_quant: bool,
504 block_shape: Optional[list[int]] = None,
505 ocp_mx_scheme: str | None = None,
506) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
507 """Quantize MoE input activations before GEMM."""
508 if ocp_mx_scheme is not None:
509 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
510 pass
511 elif ocp_mx_scheme.endswith("a_fp8"):
512 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False)
513 A = (qA.float() * qA_scale.float()).to(A.dtype)
514 return A, None
516 if quant_dtype is None:
517 return A, A_scale
518 elif quant_dtype == torch.float8_e4m3fn:
519 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
520 elif quant_dtype == torch.int8:
521 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
522 else:
523 return A, A_scale
526def _ensure_block_size_k_divisible(
527 size_k: int, block_size_k: int, group_size: int
528) -> int:
529 """Find largest block_size_k that divides size_k and is divisible by group_size."""
530 if size_k % block_size_k == 0 and block_size_k % group_size == 0:
531 return block_size_k
533 max_search = min(block_size_k, size_k)
534 start = (max_search // group_size) * group_size
535 for candidate in range(start, group_size - 1, -group_size):
536 if size_k % candidate == 0:
537 return candidate
539 if size_k % group_size == 0:
540 return group_size
542 return size_k
545@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
546@triton.jit
547def _silu_and_mul_kernel(x, y):
548 x_fp32 = x.to(tl.float32)
549 x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
550 return x_silu * y
553@triton.jit
554def write_zeros_to_output(
555 c_ptr,
556 stride_cm,
557 stride_cn,
558 pid_n,
559 N,
560 offs_token,
561 token_mask,
562 BLOCK_SIZE_M,
563 BLOCK_SIZE_N,
564 compute_type,
565):
566 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
567 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
568 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
569 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
570 tl.store(c_ptrs, accumulator, mask=c_mask)
573@triton.jit
574def fused_moe_kernel_gptq_awq(
575 # Pointers to matrices
576 a_ptr,
577 b_ptr,
578 c_ptr,
579 b_scale_ptr,
580 b_zp_ptr,
581 topk_weights_ptr,
582 sorted_token_ids_ptr,
583 expert_ids_ptr,
584 num_tokens_post_padded_ptr,
585 # Matrix dimensions
586 N: tl.constexpr,
587 K: tl.constexpr,
588 EM,
589 num_valid_tokens,
590 # The stride variables represent how much to increase the ptr by when
591 # moving by 1 element in a particular dimension. E.g. `stride_am` is
592 # how much to increase `a_ptr` by to get the element one row down
593 # (A has M rows).
594 stride_am,
595 stride_ak,
596 stride_be,
597 stride_bk,
598 stride_bn,
599 stride_cm,
600 stride_cn,
601 stride_bse,
602 stride_bsk,
603 stride_bsn,
604 stride_bze,
605 stride_bzk,
606 stride_bzn,
607 block_k_diviable: tl.constexpr,
608 group_size: tl.constexpr,
609 # Meta-parameters
610 BLOCK_SIZE_M: tl.constexpr,
611 BLOCK_SIZE_N: tl.constexpr,
612 BLOCK_SIZE_K: tl.constexpr,
613 GROUP_SIZE_M: tl.constexpr,
614 SPLIT_K: tl.constexpr,
615 MUL_ROUTED_WEIGHT: tl.constexpr,
616 top_k: tl.constexpr,
617 compute_type: tl.constexpr,
618 has_zp: tl.constexpr,
619 use_int4_w4a16: tl.constexpr,
620 use_int8_w8a16: tl.constexpr,
621):
622 """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights."""
623 # Map pid to C block (grouped ordering for L2 reuse)
624 pid = tl.program_id(axis=0)
625 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
626 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
627 num_pid_in_group = GROUP_SIZE_M * num_pid_n
628 group_id = pid // num_pid_in_group
629 first_pid_m = group_id * GROUP_SIZE_M
630 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
631 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
632 pid_n = (pid % num_pid_in_group) // group_size_m
634 # Create pointers for first blocks of A and B
635 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
636 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
637 return
638 offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
639 # Cast to int64 to prevent overflow in stride*offset products
640 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
641 token_mask = offs_token < num_valid_tokens
643 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
644 if off_experts == -1:
645 # -----------------------------------------------------------
646 # Write back zeros to the output when the expert is not
647 # in the current expert parallel rank.
648 write_zeros_to_output(
649 c_ptr,
650 stride_cm,
651 stride_cn,
652 pid_n,
653 N,
654 offs_token,
655 token_mask,
656 BLOCK_SIZE_M,
657 BLOCK_SIZE_N,
658 compute_type,
659 )
660 return
662 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
663 offs_k = tl.arange(0, BLOCK_SIZE_K)
664 a_ptrs = a_ptr + (
665 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
666 )
668 if use_int4_w4a16:
669 b_ptrs = (
670 b_ptr
671 + off_experts * stride_be
672 + (offs_k[:, None] // 2) * stride_bk
673 + offs_bn[None, :] * stride_bn
674 )
675 b_shifter = (offs_k[:, None] % 2) * 4
676 elif use_int8_w8a16:
677 b_ptrs = (
678 b_ptr
679 + off_experts * stride_be
680 + offs_k[:, None] * stride_bk
681 + offs_bn[None, :] * stride_bn
682 )
684 if not has_zp and use_int4_w4a16:
685 b_zp_num = 8
686 if not has_zp and use_int8_w8a16:
687 b_zp_num = 128
688 elif has_zp and use_int4_w4a16:
689 b_zp_shifter = (offs_bn[None, :] % 2) * 4
691 # Accumulate C block in fp32
692 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
693 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
694 if not block_k_diviable:
695 k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
696 k_other = 0.0
697 else:
698 k_mask = None
699 k_other = None
701 a = tl.load(
702 a_ptrs,
703 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
704 other=0.0,
705 )
706 b = tl.load(b_ptrs)
707 if use_int4_w4a16:
708 b = (b >> b_shifter) & 0xF
710 b_scale_ptrs = (
711 b_scale_ptr
712 + off_experts * stride_bse
713 + offs_bn[None, :] * stride_bsn
714 + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
715 )
716 b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
717 b_scale = b_scale.to(tl.float32)
719 if has_zp and use_int4_w4a16:
720 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
721 b_zp_ptrs = (
722 b_zp_ptr
723 + off_experts * stride_bze
724 + (offs_bn[None, :] // 2) * stride_bzn
725 + offs_k_true * stride_bzk
726 )
727 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
728 b_zp = (b_zp >> b_zp_shifter) & 0xF
729 b_zp = b_zp.to(tl.float32)
730 elif has_zp and use_int8_w8a16:
731 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
732 b_zp_ptrs = (
733 b_zp_ptr
734 + off_experts * stride_bze
735 + offs_bn[None, :] * stride_bzn
736 + offs_k_true * stride_bzk
737 )
738 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
739 b_zp = b_zp.to(tl.float32)
741 if has_zp:
742 b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
743 else:
744 b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
745 accumulator = tl.dot(a, b, acc=accumulator)
747 a_ptrs += BLOCK_SIZE_K * stride_ak
748 if use_int4_w4a16:
749 b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
750 else:
751 b_ptrs += BLOCK_SIZE_K * stride_bk
753 if MUL_ROUTED_WEIGHT:
754 moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
755 accumulator = accumulator * moe_weight[:, None]
757 accumulator = accumulator.to(compute_type)
758 # Write back output
759 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
760 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
761 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
762 tl.store(c_ptrs, accumulator, mask=c_mask)
765@triton.jit
766def fused_moe_kernel(
767 # Pointers to matrices
768 a_ptr,
769 b_ptr,
770 c_ptr,
771 b_bias_ptr,
772 a_scale_ptr,
773 b_scale_ptr,
774 topk_weights_ptr,
775 sorted_token_ids_ptr,
776 expert_ids_ptr,
777 num_tokens_post_padded_ptr,
778 # Matrix dimensions
779 N,
780 K,
781 EM,
782 num_valid_tokens,
783 stride_am,
784 stride_ak,
785 stride_be,
786 stride_bk,
787 stride_bn,
788 stride_cm,
789 stride_cn,
790 stride_asm,
791 stride_ask,
792 stride_bse,
793 stride_bsk,
794 stride_bsn,
795 stride_bbe, # bias expert stride
796 stride_bbn, # bias N stride
797 # Block size for block-wise quantization
798 group_n: tl.constexpr,
799 group_k: tl.constexpr,
800 naive_block_assignment: tl.constexpr,
801 # Meta-parameters
802 BLOCK_SIZE_M: tl.constexpr,
803 BLOCK_SIZE_N: tl.constexpr,
804 BLOCK_SIZE_K: tl.constexpr,
805 GROUP_SIZE_M: tl.constexpr,
806 SPLIT_K: tl.constexpr,
807 MUL_ROUTED_WEIGHT: tl.constexpr,
808 top_k: tl.constexpr,
809 compute_type: tl.constexpr,
810 use_fp8_w8a8: tl.constexpr,
811 use_int8_w8a8: tl.constexpr,
812 use_int8_w8a16: tl.constexpr,
813 per_channel_quant: tl.constexpr,
814 HAS_BIAS: tl.constexpr,
815):
816 """Fused MoE kernel: token × expert GEMM with quantization support."""
817 # Map pid to C block (grouped ordering for L2 reuse)
818 pid = tl.program_id(axis=0)
819 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
820 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
821 num_pid_in_group = GROUP_SIZE_M * num_pid_n
822 group_id = pid // num_pid_in_group
823 first_pid_m = group_id * GROUP_SIZE_M
824 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
825 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
826 pid_n = (pid % num_pid_in_group) // group_size_m
828 # Create pointers for first blocks of A and B
829 offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
830 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
831 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
832 return
833 if not naive_block_assignment:
834 offs_token_id = pid_m * BLOCK_SIZE_M + offs
835 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
836 else:
837 offs_token = tl.where(
838 offs == 0,
839 pid_m, # first element = pid_m
840 num_valid_tokens, # remaining elements = constant
841 )
842 offs_token = offs_token.to(tl.int64) # prevent int32 overflow
844 token_mask = offs_token < num_valid_tokens
846 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
847 if off_experts == -1:
848 # Expert not in current EP rank, write zeros
849 write_zeros_to_output(
850 c_ptr,
851 stride_cm,
852 stride_cn,
853 pid_n,
854 N,
855 offs_token,
856 token_mask,
857 BLOCK_SIZE_M,
858 BLOCK_SIZE_N,
859 compute_type,
860 )
861 return
863 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
864 offs_k = tl.arange(0, BLOCK_SIZE_K)
865 a_ptrs = a_ptr + (
866 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
867 )
869 b_ptrs = (
870 b_ptr
871 + off_experts * stride_be
872 + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
873 )
874 if use_int8_w8a16:
875 b_scale_ptrs = (
876 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
877 )
878 b_scale = tl.load(b_scale_ptrs)
880 if use_fp8_w8a8 or use_int8_w8a8:
881 if group_k > 0 and group_n > 0: # block-wise
882 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
883 offs_bsn = offs_bn // group_n
884 b_scale_ptrs = (
885 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
886 )
887 elif per_channel_quant: # channel-wise
888 b_scale_ptrs = (
889 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
890 )
891 b_scale = tl.load(b_scale_ptrs)
892 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
893 a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
894 else: # tensor-wise
895 a_scale = tl.load(a_scale_ptr)
896 b_scale = tl.load(b_scale_ptr + off_experts)
897 if HAS_BIAS:
898 bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
899 bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
900 # Accumulate C block in fp32
901 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
902 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
903 a = tl.load(
904 a_ptrs,
905 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
906 other=0.0,
907 )
908 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
909 if use_int8_w8a16:
910 accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
911 elif use_fp8_w8a8 or use_int8_w8a8:
912 if group_k > 0 and group_n > 0:
913 k_start = k * BLOCK_SIZE_K
914 offs_ks = k_start // group_k
915 a_scale = tl.load(
916 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
917 )
918 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
920 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
921 else:
922 if use_fp8_w8a8:
923 accumulator = tl.dot(a, b, acc=accumulator)
924 else:
925 accumulator += tl.dot(a, b)
926 else:
927 accumulator += tl.dot(a, b)
928 a_ptrs += BLOCK_SIZE_K * stride_ak
929 b_ptrs += BLOCK_SIZE_K * stride_bk
931 # Dequantization
932 if use_int8_w8a16:
933 accumulator = accumulator * b_scale
934 elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
935 accumulator = accumulator * a_scale * b_scale
937 if HAS_BIAS:
938 accumulator += bias[None, :]
940 # Router weight multiplication (must be in fp32)
941 if MUL_ROUTED_WEIGHT:
942 moe_weight = tl.load(
943 topk_weights_ptr + offs_token,
944 mask=token_mask,
945 other=0,
946 )
947 accumulator *= moe_weight[:, None]
949 accumulator = accumulator.to(compute_type)
951 # Write back output
952 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
953 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
954 c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
955 tl.store(c_ptrs, accumulator, mask=c_mask)
958def invoke_fused_moe_wna16_triton_kernel(
959 A: torch.Tensor,
960 B: torch.Tensor,
961 C: torch.Tensor,
962 B_scale: torch.Tensor | None,
963 B_zp: torch.Tensor | None,
964 topk_weights: torch.Tensor | None,
965 sorted_token_ids: torch.Tensor,
966 expert_ids: torch.Tensor,
967 num_tokens_post_padded: torch.Tensor,
968 mul_routed_weight: bool,
969 top_k: int,
970 config: dict[str, Any],
971 compute_type: tl.dtype,
972 use_int8_w8a16: bool,
973 use_int4_w4a16: bool,
974 block_shape: list[int] | None,
975):
976 assert B_scale is not None and B_scale.ndim == 3
977 assert B_zp is None or B_zp.ndim == 3
978 assert block_shape is not None and block_shape[0] == 0
980 M = A.size(0)
981 num_tokens = M * top_k
983 EM = sorted_token_ids.size(0)
984 if A.size(0) < config["BLOCK_SIZE_M"]:
985 # optimize for small batch_size.
986 # We assume that top_ids of each token is unique,
987 # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
988 # and we can skip some invalid blocks.
989 EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
990 grid = lambda META: (
991 triton.cdiv(EM, META["BLOCK_SIZE_M"])
992 * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
993 )
994 config = config.copy()
995 config.update(
996 get_moe_wna16_block_config(
997 config=config,
998 use_moe_wna16_cuda=False,
999 num_valid_tokens=num_tokens,
1000 size_k=A.size(1),
1001 size_n=B.size(1),
1002 num_experts=B.size(1),
1003 group_size=block_shape[1],
1004 real_top_k=top_k,
1005 block_size_m=config["BLOCK_SIZE_M"],
1006 )
1007 )
1009 fused_moe_kernel_gptq_awq[grid](
1010 A,
1011 B,
1012 C,
1013 B_scale,
1014 B_zp,
1015 topk_weights,
1016 sorted_token_ids,
1017 expert_ids,
1018 num_tokens_post_padded,
1019 B.size(1),
1020 A.size(1),
1021 EM,
1022 num_tokens,
1023 A.stride(0),
1024 A.stride(1),
1025 B.stride(0),
1026 B.stride(2),
1027 B.stride(1),
1028 C.stride(1),
1029 C.stride(2),
1030 B_scale.stride(0),
1031 B_scale.stride(2),
1032 B_scale.stride(1),
1033 B_zp.stride(0) if B_zp is not None else 0,
1034 B_zp.stride(2) if B_zp is not None else 0,
1035 B_zp.stride(1) if B_zp is not None else 0,
1036 block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
1037 group_size=block_shape[1],
1038 MUL_ROUTED_WEIGHT=mul_routed_weight,
1039 top_k=top_k,
1040 compute_type=compute_type,
1041 has_zp=B_zp is not None,
1042 use_int4_w4a16=use_int4_w4a16,
1043 use_int8_w8a16=use_int8_w8a16,
1044 **config,
1045 )
1048def invoke_fused_moe_triton_kernel(
1049 A: torch.Tensor,
1050 B: torch.Tensor,
1051 C: torch.Tensor,
1052 A_scale: Optional[torch.Tensor],
1053 B_scale: Optional[torch.Tensor],
1054 topk_weights: Optional[torch.Tensor],
1055 sorted_token_ids: torch.Tensor,
1056 expert_ids: torch.Tensor,
1057 num_tokens_post_padded: torch.Tensor,
1058 mul_routed_weight: bool,
1059 top_k: int,
1060 config: dict[str, Any],
1061 compute_type: tl.dtype,
1062 use_fp8_w8a8: bool = False,
1063 use_int8_w8a8: bool = False,
1064 use_int8_w8a16: bool = False,
1065 use_int4_w4a16: bool = False,
1066 per_channel_quant: bool = False,
1067 block_shape: Optional[list[int]] = None,
1068 B_bias: torch.Tensor | None = None,
1069) -> None:
1070 """Launch the fused_moe_kernel Triton kernel."""
1071 assert topk_weights is not None or not mul_routed_weight
1072 assert topk_weights is None or topk_weights.stride(1) == 1
1073 assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
1075 if use_fp8_w8a8 or use_int8_w8a8:
1076 assert B_scale is not None
1077 assert block_shape is None or triton.cdiv(
1078 B.size(-2), block_shape[0]
1079 ) == B_scale.size(-2)
1080 assert block_shape is None or triton.cdiv(
1081 B.size(-1), block_shape[1]
1082 ) == B_scale.size(-1)
1083 elif use_int8_w8a16 or use_int4_w4a16:
1084 assert B_scale is not None
1085 assert block_shape is None or block_shape[0] == 0
1086 else:
1087 assert A_scale is None
1088 assert B_scale is None
1090 M = A.size(0)
1091 num_tokens = M * top_k
1092 if sorted_token_ids is not None:
1093 EM = sorted_token_ids.size(0)
1094 if A.size(0) < config["BLOCK_SIZE_M"]:
1095 EM = min(
1096 sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
1097 )
1098 else:
1099 EM = num_tokens * config["BLOCK_SIZE_M"]
1100 grid = lambda META: (
1101 triton.cdiv(EM, META["BLOCK_SIZE_M"])
1102 * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
1103 )
1104 HAS_BIAS = B_bias is not None
1106 config = config.copy()
1107 config["SPLIT_K"] = 1
1108 BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
1109 if block_shape is not None:
1110 BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
1112 fused_moe_kernel[grid](
1113 A,
1114 B,
1115 C,
1116 B_bias,
1117 A_scale,
1118 B_scale,
1119 topk_weights,
1120 sorted_token_ids,
1121 expert_ids,
1122 num_tokens_post_padded,
1123 B.size(1), # N
1124 B.size(2), # K
1125 EM,
1126 num_tokens,
1127 A.stride(0),
1128 A.stride(1),
1129 B.stride(0),
1130 B.stride(2),
1131 B.stride(1),
1132 C.stride(1),
1133 C.stride(2),
1134 A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
1135 A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
1136 B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
1137 B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
1138 B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
1139 B_bias.stride(0) if B_bias is not None else 0,
1140 B_bias.stride(1) if B_bias is not None else 0,
1141 0 if block_shape is None else block_shape[0],
1142 0 if block_shape is None else block_shape[1],
1143 MUL_ROUTED_WEIGHT=mul_routed_weight,
1144 top_k=top_k,
1145 compute_type=compute_type,
1146 use_fp8_w8a8=use_fp8_w8a8,
1147 use_int8_w8a8=use_int8_w8a8,
1148 use_int8_w8a16=use_int8_w8a16,
1149 per_channel_quant=per_channel_quant,
1150 naive_block_assignment=(sorted_token_ids is None),
1151 HAS_BIAS=HAS_BIAS,
1152 BLOCK_SIZE_K=BLOCK_SIZE_K,
1153 **config,
1154 )
1157def dispatch_fused_moe_kernel(
1158 A: torch.Tensor,
1159 B: torch.Tensor,
1160 C: torch.Tensor,
1161 A_scale: Optional[torch.Tensor],
1162 B_scale: Optional[torch.Tensor],
1163 B_zp: Optional[torch.Tensor],
1164 topk_weights: Optional[torch.Tensor],
1165 sorted_token_ids: torch.Tensor,
1166 expert_ids: torch.Tensor,
1167 num_tokens_post_padded: torch.Tensor,
1168 mul_routed_weight: bool,
1169 top_k: int,
1170 config: dict[str, Any],
1171 compute_type: tl.dtype,
1172 use_fp8_w8a8: bool,
1173 use_int8_w8a8: bool,
1174 use_int8_w8a16: bool,
1175 use_int4_w4a16: bool,
1176 per_channel_quant: bool,
1177 block_shape: Optional[list[int]] = None,
1178 B_bias: Optional[torch.Tensor] = None,
1179) -> None:
1180 """Dispatch to the appropriate fused MoE kernel based on quantization flags."""
1181 assert topk_weights is not None or not mul_routed_weight
1182 assert topk_weights is None or topk_weights.stride(1) == 1
1183 assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
1185 # M = A.size(0)
1186 # num_tokens = M * top_k
1188 if False:
1189 # TODO: Other precision-specific implementations
1190 # use_fp8_w8a8,
1191 # use_int8_w8a8,
1192 # use_int8_w8a16,
1193 # use_int4_w4a16,
1194 pass
1195 if (use_int8_w8a16 or use_int4_w4a16) and (
1196 block_shape is not None and block_shape[1] > 0
1197 ):
1198 assert B_bias is None
1199 invoke_fused_moe_wna16_triton_kernel(
1200 A,
1201 B,
1202 C,
1203 B_scale,
1204 B_zp,
1205 topk_weights,
1206 sorted_token_ids,
1207 expert_ids,
1208 num_tokens_post_padded,
1209 mul_routed_weight,
1210 top_k,
1211 config,
1212 compute_type,
1213 use_int8_w8a16,
1214 use_int4_w4a16,
1215 block_shape,
1216 )
1217 else:
1218 invoke_fused_moe_triton_kernel(
1219 A,
1220 B,
1221 C,
1222 A_scale,
1223 B_scale,
1224 topk_weights,
1225 sorted_token_ids,
1226 expert_ids,
1227 num_tokens_post_padded,
1228 mul_routed_weight,
1229 top_k,
1230 config,
1231 compute_type,
1232 use_fp8_w8a8,
1233 use_int8_w8a8,
1234 use_int8_w8a16,
1235 use_int4_w4a16,
1236 per_channel_quant,
1237 block_shape,
1238 B_bias,
1239 )
1242def fused_experts_impl(
1243 hidden_states: torch.Tensor,
1244 w1: torch.Tensor,
1245 w2: torch.Tensor,
1246 topk_weights: torch.Tensor,
1247 topk_ids: torch.Tensor,
1248 inplace: bool = False,
1249 activation: str = "silu",
1250 apply_router_weight_on_input: bool = False,
1251 use_fp8_w8a8: bool = False,
1252 use_int8_w8a8: bool = False,
1253 use_int8_w8a16: bool = False,
1254 use_int4_w4a16: bool = False,
1255 ocp_mx_scheme: str | None = None,
1256 per_channel_quant: bool = False,
1257 global_num_experts: int = -1,
1258 expert_map: torch.Tensor | None = None,
1259 w1_scale: Optional[torch.Tensor] = None,
1260 w2_scale: Optional[torch.Tensor] = None,
1261 w1_zp: torch.Tensor | None = None,
1262 w2_zp: torch.Tensor | None = None,
1263 a1_scale: Optional[torch.Tensor] = None,
1264 a2_scale: Optional[torch.Tensor] = None,
1265 block_shape: Optional[list[int]] = None,
1266 w1_bias: Optional[torch.Tensor] = None,
1267 w2_bias: Optional[torch.Tensor] = None,
1268) -> torch.Tensor:
1269 logger.debug("GEMS FUSED MOE")
1270 assert (
1271 activation == "silu"
1272 ), f"Only 'silu' activation is supported, got {activation}"
1274 activation_enum = MoEActivation.from_str(activation)
1276 # Check constraints
1277 if use_int4_w4a16:
1278 # INT4 stored unpacked in INT8 containers (full K dim)
1279 assert hidden_states.size(1) == w1.size(
1280 2
1281 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
1282 elif ocp_mx_scheme is not None:
1283 if ocp_mx_scheme.startswith("w_mxfp4"):
1284 assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
1285 elif ocp_mx_scheme.startswith("w_mxfp6"):
1286 assert (
1287 hidden_states.size(1) == (w1.size(2) * 4) // 3
1288 ), "hidden size mismatch"
1289 else:
1290 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
1291 else:
1292 assert hidden_states.size(1) == w1.size(
1293 2
1294 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
1296 assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
1297 assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1298 assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
1299 assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
1300 assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
1302 num_tokens = hidden_states.size(0)
1303 E, N, _ = w1.size()
1304 K = w2.size(1)
1305 if global_num_experts == -1:
1306 global_num_experts = E
1307 top_k_num = topk_ids.size(1)
1309 CHUNK_SIZE: int = 64 * 1024
1310 M = min(num_tokens, CHUNK_SIZE)
1312 config_dtype = _get_config_dtype_str(
1313 use_fp8_w8a8=use_fp8_w8a8,
1314 use_int8_w8a16=use_int8_w8a16,
1315 use_int4_w4a16=use_int4_w4a16,
1316 ocp_mx_scheme=ocp_mx_scheme,
1317 dtype=hidden_states.dtype,
1318 )
1320 quant_dtype = _get_config_quant_dtype(
1321 use_fp8_w8a8=use_fp8_w8a8,
1322 use_int8_w8a8=use_int8_w8a8,
1323 ocp_mx_scheme=ocp_mx_scheme,
1324 )
1326 get_config_func = functools.partial(
1327 try_get_optimal_moe_config,
1328 w1.size(),
1329 w2.size(),
1330 top_k_num,
1331 config_dtype,
1332 block_shape=block_shape,
1333 )
1335 config = get_config_func(M)
1337 # cache1 and cache3 share memory (non-overlapping lifetime)
1338 cache13 = torch.empty(
1339 M * top_k_num * max(N, K),
1340 device=hidden_states.device,
1341 dtype=hidden_states.dtype,
1342 )
1343 intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
1344 intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
1346 # cache2 needs separate memory (concurrent with cache1)
1347 activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
1348 intermediate_cache2 = torch.empty(
1349 (M * top_k_num, activation_out_dim),
1350 device=hidden_states.device,
1351 dtype=hidden_states.dtype,
1352 )
1354 if hidden_states.dtype == torch.bfloat16:
1355 compute_type = tl.bfloat16
1356 elif hidden_states.dtype == torch.float16:
1357 compute_type = tl.float16
1358 elif hidden_states.dtype == torch.float32:
1359 compute_type = tl.float32
1360 else:
1361 raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1363 out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
1365 if ocp_mx_scheme is not None:
1366 # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
1367 if ocp_mx_scheme.startswith("w_mxfp4"):
1368 w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
1369 w1_scale = None
1370 w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
1371 w2_scale = None
1372 elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
1373 w1 = dequant_mxfp6(
1374 w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
1375 )
1376 w1_scale = None
1377 w2 = dequant_mxfp6(
1378 w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
1379 )
1380 w2_scale = None
1381 elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
1382 w1 = dequant_mxfp6(
1383 w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
1384 )
1385 w1_scale = None
1386 w2 = dequant_mxfp6(
1387 w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
1388 )
1389 w2_scale = None
1390 else:
1391 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
1393 # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
1394 if use_int8_w8a16 or use_int4_w4a16:
1395 w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
1396 w1_scale = None
1397 w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
1398 w2_scale = None
1399 use_int8_w8a16 = False
1400 use_int4_w4a16 = False
1402 for chunk in range((num_tokens // CHUNK_SIZE) + 1):
1403 begin_chunk_idx, end_chunk_idx = (
1404 chunk * CHUNK_SIZE,
1405 min((chunk + 1) * CHUNK_SIZE, num_tokens),
1406 )
1407 curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
1408 tokens_in_chunk, _ = curr_hidden_states.size()
1410 if tokens_in_chunk == 0:
1411 break
1413 if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
1414 # Adjust cache size for last chunk
1415 intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
1416 intermediate_cache2 = intermediate_cache2[
1417 : tokens_in_chunk * topk_ids.size(1)
1418 ]
1419 intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
1420 config = get_config_func(tokens_in_chunk)
1422 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
1423 curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
1424 qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
1425 A=curr_hidden_states,
1426 A_scale=a1_scale,
1427 quant_dtype=quant_dtype,
1428 per_act_token_quant=per_channel_quant,
1429 block_shape=block_shape,
1430 ocp_mx_scheme=ocp_mx_scheme,
1431 )
1433 SPARSITY_FACTOR = 4
1434 naive_block_assignment = (
1435 expert_map is None
1436 and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
1437 and not (
1438 (use_int8_w8a16 or use_int4_w4a16)
1439 and block_shape is not None
1440 and block_shape[1] > 0
1441 )
1442 )
1444 if not naive_block_assignment:
1445 sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
1446 curr_topk_ids,
1447 config["BLOCK_SIZE_M"],
1448 global_num_experts,
1449 expert_map,
1450 # ignore_invalid_experts=True,
1451 )
1452 else:
1453 max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
1454 expert_ids = curr_topk_ids.view(-1)
1455 num_tokens_post_padded = torch.empty(
1456 (1), dtype=torch.int32, device=topk_ids.device
1457 )
1458 num_tokens_post_padded.fill_(max_num_tokens_padded)
1459 sorted_token_ids = None
1461 dispatch_fused_moe_kernel(
1462 qcurr_hidden_states,
1463 w1,
1464 intermediate_cache1,
1465 a1q_scale,
1466 w1_scale,
1467 w1_zp,
1468 curr_topk_weights,
1469 sorted_token_ids,
1470 expert_ids,
1471 num_tokens_post_padded,
1472 apply_router_weight_on_input,
1473 top_k_num,
1474 config,
1475 compute_type=compute_type,
1476 use_fp8_w8a8=use_fp8_w8a8,
1477 use_int8_w8a8=use_int8_w8a8,
1478 use_int8_w8a16=use_int8_w8a16,
1479 use_int4_w4a16=use_int4_w4a16,
1480 per_channel_quant=per_channel_quant,
1481 block_shape=block_shape,
1482 B_bias=w1_bias,
1483 )
1485 apply_moe_activation(
1486 activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
1487 )
1489 qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
1490 A=intermediate_cache2,
1491 A_scale=a2_scale,
1492 quant_dtype=quant_dtype,
1493 per_act_token_quant=per_channel_quant,
1494 block_shape=block_shape,
1495 ocp_mx_scheme=ocp_mx_scheme,
1496 )
1498 if expert_map is not None:
1499 intermediate_cache3.zero_()
1501 dispatch_fused_moe_kernel(
1502 qintermediate_cache2,
1503 w2,
1504 intermediate_cache3,
1505 a2q_scale,
1506 w2_scale,
1507 w2_zp,
1508 curr_topk_weights,
1509 sorted_token_ids,
1510 expert_ids,
1511 num_tokens_post_padded,
1512 not apply_router_weight_on_input,
1513 1,
1514 config,
1515 compute_type=compute_type,
1516 use_fp8_w8a8=use_fp8_w8a8,
1517 use_int8_w8a8=use_int8_w8a8,
1518 use_int8_w8a16=use_int8_w8a16,
1519 use_int4_w4a16=use_int4_w4a16,
1520 per_channel_quant=per_channel_quant,
1521 block_shape=block_shape,
1522 B_bias=w2_bias,
1523 )
1525 moe_sum(
1526 intermediate_cache3.view(*intermediate_cache3.size()),
1527 out_hidden_states[begin_chunk_idx:end_chunk_idx],
1528 )
1530 return out_hidden_states
1533def inplace_fused_experts(
1534 hidden_states: torch.Tensor,
1535 w1: torch.Tensor,
1536 w2: torch.Tensor,
1537 topk_weights: torch.Tensor,
1538 topk_ids: torch.Tensor,
1539 activation: str = "silu",
1540 apply_router_weight_on_input: bool = False,
1541 use_fp8_w8a8: bool = False,
1542 use_int8_w8a8: bool = False,
1543 use_int8_w8a16: bool = False,
1544 use_int4_w4a16: bool = False,
1545 per_channel_quant: bool = False,
1546 global_num_experts: int = -1,
1547 w1_scale: Optional[torch.Tensor] = None,
1548 w2_scale: Optional[torch.Tensor] = None,
1549 a1_scale: Optional[torch.Tensor] = None,
1550 a2_scale: Optional[torch.Tensor] = None,
1551 block_shape: Optional[list[int]] = None,
1552 w1_bias: Optional[torch.Tensor] = None,
1553 w2_bias: Optional[torch.Tensor] = None,
1554) -> None:
1555 """
1556 In-place fused MoE: writes output directly into ``hidden_states``.
1558 Same semantics as ``fused_experts_impl(..., inplace=True)``.
1559 Returns None (the result is stored in ``hidden_states``).
1560 """
1561 fused_experts_impl(
1562 hidden_states,
1563 w1,
1564 w2,
1565 topk_weights,
1566 topk_ids,
1567 inplace=True,
1568 activation=activation,
1569 apply_router_weight_on_input=apply_router_weight_on_input,
1570 use_fp8_w8a8=use_fp8_w8a8,
1571 use_int8_w8a8=use_int8_w8a8,
1572 use_int8_w8a16=use_int8_w8a16,
1573 use_int4_w4a16=use_int4_w4a16,
1574 per_channel_quant=per_channel_quant,
1575 global_num_experts=global_num_experts,
1576 w1_scale=w1_scale,
1577 w2_scale=w2_scale,
1578 a1_scale=a1_scale,
1579 a2_scale=a2_scale,
1580 block_shape=block_shape,
1581 w1_bias=w1_bias,
1582 w2_bias=w2_bias,
1583 )
1586def outplace_fused_experts(
1587 hidden_states: torch.Tensor,
1588 w1: torch.Tensor,
1589 w2: torch.Tensor,
1590 topk_weights: torch.Tensor,
1591 topk_ids: torch.Tensor,
1592 activation: str = "silu",
1593 apply_router_weight_on_input: bool = False,
1594 use_fp8_w8a8: bool = False,
1595 use_int8_w8a8: bool = False,
1596 use_int8_w8a16: bool = False,
1597 use_int4_w4a16: bool = False,
1598 per_channel_quant: bool = False,
1599 global_num_experts: int = -1,
1600 w1_scale: Optional[torch.Tensor] = None,
1601 w2_scale: Optional[torch.Tensor] = None,
1602 a1_scale: Optional[torch.Tensor] = None,
1603 a2_scale: Optional[torch.Tensor] = None,
1604 block_shape: Optional[list[int]] = None,
1605 w1_bias: Optional[torch.Tensor] = None,
1606 w2_bias: Optional[torch.Tensor] = None,
1607) -> torch.Tensor:
1608 """
1609 Out-of-place fused MoE: allocates and returns a new output tensor.
1611 Same semantics as ``fused_experts_impl(..., inplace=False)``.
1612 """
1613 return fused_experts_impl(
1614 hidden_states,
1615 w1,
1616 w2,
1617 topk_weights,
1618 topk_ids,
1619 inplace=False,
1620 activation=activation,
1621 apply_router_weight_on_input=apply_router_weight_on_input,
1622 use_fp8_w8a8=use_fp8_w8a8,
1623 use_int8_w8a8=use_int8_w8a8,
1624 use_int8_w8a16=use_int8_w8a16,
1625 use_int4_w4a16=use_int4_w4a16,
1626 per_channel_quant=per_channel_quant,
1627 global_num_experts=global_num_experts,
1628 w1_scale=w1_scale,
1629 w2_scale=w2_scale,
1630 a1_scale=a1_scale,
1631 a2_scale=a2_scale,
1632 block_shape=block_shape,
1633 w1_bias=w1_bias,
1634 w2_bias=w2_bias,
1635 )