Coverage for src/flag_gems/fused/fused_moe.py: 10%
163 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.fused.moe_align_block_size import moe_align_block_size
9from flag_gems.fused.moe_sum import moe_sum
10from flag_gems.fused.silu_and_mul import silu_and_mul_kernel
12logger = logging.getLogger(__name__)
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Zero-fill this program's (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C.

    Invoked when an M-block maps to an invalid expert (expert id == -1),
    so the affected output rows must be written as zeros instead of being
    computed.
    """
    # Column offsets covered by this program's N-tile.
    cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Destination addresses: one row per sorted token, one column per offset.
    dst = c_ptr + stride_cm * offs_token[:, None] + stride_cn * cols[None, :]
    # Store only to valid tokens and in-bounds columns.
    keep = token_mask[:, None] & (cols[None, :] < N)
    zeros_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    tl.store(dst, zeros_tile, mask=keep)
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
):
    """
    Fused MoE GEMM kernel with expert-based indirect addressing.

    Computes: C[t, :] = A[t // topk, :] @ B[expert(t), :, :] [* topk_weight[t]]

    Key Parameters:
    - A: Input activations [M, K] (or quantized)
    - B: Stacked expert weights [E, N, K]
    - C: Output [num_sorted_tokens, N] (indexed by sorted_token_ids)
    - sorted_token_ids: Per-expert sorted token indices (from moe_align_block_size)
    - expert_ids: Expert index for each M-block
    - num_valid_tokens: M * topk; sorted ids >= this value are padding slots
    - group_n / group_k: quant-block sizes (0 disables block-wise quantization)
    """
    # Map program id to the block of C it should compute.
    # Grouped ordering promotes L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # Last group may be smaller than GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Load sorted token indices for this M-block.
    # int64 offsets avoid overflow in pointer arithmetic on large tensors.
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Early-exit: this M-block lies entirely in the padded tail.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + offs
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    offs_token = offs_token.to(tl.int64)
    # Padding slots carry ids >= num_valid_tokens; mask them out of loads/stores.
    token_mask = offs_token < num_valid_tokens
    # Determine which expert this block belongs to
    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    # -1 marks an invalid expert for this block: emit zeros and bail out.
    if off_experts == -1:
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return
    # Set up A and B pointers.
    # offs_bn is wrapped with % N so loads stay in bounds; the final store is
    # still masked with (offs_cn < N), so wrapped columns are never written.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # A rows are shared by the top_k duplicates of a token: row = token // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    # Load quantization scales based on mode
    if use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            # block-wise quantization: scales are re-loaded per K-step inside
            # the main loop; only the base pointers are computed here.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        elif per_channel_quant:
            # per-channel quantization: one B scale per output column,
            # one A scale per token row.
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:
            # per-tensor quantization: scalar scale for A, one scalar per expert for B.
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    # Main GEMM loop: accumulate in float32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K-tail on the last iteration; invalid tokens load 0.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # block-wise: pick the scale block covering this K-step and
                # dequantize each partial product before accumulating.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                # per-channel / per-tensor: dequantize once after the loop.
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            # Fused dot-accumulate: on SM90 this maps to WGMMA with
            # in-place accumulation, avoiding a separate add instruction.
            accumulator = tl.dot(a, b, acc=accumulator)
        # Advance both operand pointers to the next K-tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Post-loop dequantization (per-channel and per-tensor modes only;
    # block-wise already scaled inside the loop).
    if (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
        accumulator = accumulator * a_scale * b_scale
    # Router weight multiplication (in float32 for numerical stability)
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]
    accumulator = accumulator.to(compute_type)
    # Write back
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
217def get_default_config(
218 M: int,
219 E: int,
220 N: int,
221 K: int,
222 topk: int,
223 dtype: str | None,
224 block_shape: list[int] | None = None,
225) -> dict[str, int]:
226 """Return a reasonable default Triton config for the fused MoE kernel."""
227 if dtype == "fp8_w8a8" and block_shape is not None:
228 config = {
229 "BLOCK_SIZE_M": 16 if M <= 64 else 64,
230 "BLOCK_SIZE_N": block_shape[0],
231 "BLOCK_SIZE_K": block_shape[1],
232 "GROUP_SIZE_M": 1 if M <= 16 else 32,
233 "num_warps": 4,
234 "num_stages": 3,
235 }
236 else:
237 if M <= 32:
238 block_m = 16
239 elif M <= 96:
240 block_m = 32
241 elif M <= 512:
242 block_m = 64
243 else:
244 block_m = 128
246 # --- Tile sizing optimised for H100/H800 SM90 GPUs ---
247 # Larger N/K tiles improve compute intensity and reduce grid
248 # launches for the common case where N is large (e.g. 14336).
249 if N >= 4096:
250 block_n = 128 if M <= 128 else 256
251 elif N >= 1024:
252 block_n = 64 if M <= 64 else 128
253 else:
254 block_n = 64 if M <= 64 else 128
256 # K-tile: 128 gives better arithmetic intensity.
257 if dtype == "fp8_w8a8":
258 block_k = 128
259 elif K >= 4096 or M <= 64:
260 block_k = 128
261 else:
262 block_k = 64
264 # Group-M: promotes L2 reuse across M-blocks.
265 tokens_per_expert = (M * topk) // max(E, 1)
266 if tokens_per_expert > 128:
267 group_m = 16
268 elif tokens_per_expert > 32:
269 group_m = 8
270 else:
271 group_m = 1
273 num_warps = 4 if block_m * block_n < 8192 else 8
274 num_stages = 3
276 # Shared-memory guard (~232 KB on H100/H800).
277 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
278 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
279 num_stages -= 1
281 config = {
282 "BLOCK_SIZE_M": block_m,
283 "BLOCK_SIZE_N": block_n,
284 "BLOCK_SIZE_K": block_k,
285 "GROUP_SIZE_M": group_m,
286 "num_warps": num_warps,
287 "num_stages": num_stages,
288 }
289 return config
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> None:
    """
    Launch the fused_moe_kernel Triton kernel.

    Args:
        A: Input activations [M, K]
        B: Expert weight matrices [E, N, K]
        C: Output buffer [M, topk, N]
        A_scale: Activation quantization scale (or None)
        B_scale: Weight quantization scale (or None)
        topk_weights: Router weights [M, topk] (or None)
        sorted_token_ids: From moe_align_block_size
        expert_ids: From moe_align_block_size
        num_tokens_post_padded: From moe_align_block_size
        mul_routed_weight: Whether to multiply router weights in-kernel
        top_k: Number of top experts per token
        config: Triton config dict with BLOCK_SIZE_M/N/K, GROUP_SIZE_M, etc.
        compute_type: Triton dtype for compute (tl.bfloat16, tl.float16, etc.)
        use_fp8_w8a8: FP8 weight+activation quantization
        use_int8_w8a8: INT8 weight+activation quantization
        per_channel_quant: Per-channel quantization mode
        block_shape: [block_n, block_k] for block-wise quantization
    """
    # Router weights are required whenever they must be applied in-kernel,
    # and the kernel indexes them with unit stride.
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1
    # Scales must be present exactly when a quantized mode is selected.
    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None
    M = A.size(0)
    # Total number of (token, expert) pairs; ids >= this are padding slots.
    num_tokens = M * top_k
    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # Small-batch optimization: cap EM so the grid does not launch
        # programs for M-blocks that can only contain padding.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    # 1-D grid: (M-blocks) x (N-blocks); the kernel de-linearizes the pid.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    # Copy before popping so the caller's config dict is not mutated.
    config = config.copy()
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # K-tile must not span multiple quant blocks.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
    # NOTE: positional argument order below must match fused_moe_kernel's
    # signature exactly. B strides are passed as (expert, K, N) since B is
    # laid out [E, N, K]; C strides are (1) and (2) because C is [M, topk, N]
    # and the kernel addresses it as a flat [M * topk, N] matrix.
    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        # Scale strides default to 0 when the scale tensor is absent or
        # lower-dimensional (per-tensor scalar scales).
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        # group_n / group_k: 0 disables block-wise quantization in-kernel.
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        per_channel_quant=per_channel_quant,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )
def _apply_silu_and_mul(out: torch.Tensor, inp: torch.Tensor) -> None:
    """Apply SiLU-and-Mul activation: out = SiLU(inp[:, :N]) * inp[:, N:]."""
    half = inp.shape[-1] // 2
    # Split the fused projection into its gate and up halves.
    gate = inp[:, :half]
    up = inp[:, half:]
    silu_and_mul_kernel(gate, up, out0=out)
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int = -1,
    activation: str = "silu",
) -> torch.Tensor:
    """
    Complete fused MoE forward pass (bf16/fp16/fp32, no quantization).

    Pipeline:
        moe_align_block_size → GEMM1(up+gate) → SiLU+Mul → GEMM2(down) → moe_sum

    Args:
        hidden_states: [num_tokens, hidden_size]
        w1: [E, intermediate_size * 2, hidden_size] (gate + up projection)
        w2: [E, hidden_size, intermediate_size] (down projection)
        topk_weights: [num_tokens, topk]
        topk_ids: [num_tokens, topk]
        num_experts: Total number of experts (default: inferred from w1)
        activation: Activation function name (only "silu" is supported)

    Returns:
        output: [num_tokens, hidden_size]

    Raises:
        AssertionError: If activation != "silu".
        ValueError: If hidden_states has an unsupported dtype.
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"
    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w1.shape[1]  # intermediate_size * 2
    top_k = topk_ids.shape[1]
    if num_experts <= 0:
        num_experts = E
    # Determine compute type
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")
    # Get kernel config. NOTE(review): w2.shape[1] (hidden_size) is passed as
    # the N dimension and the same config is reused for both GEMMs — confirm
    # this sizing is intentional for GEMM1, whose N is intermediate_size * 2.
    config = get_default_config(M, E, w2.shape[1], K, top_k, None)
    # Step 1: Align tokens to experts (sorted/padded to BLOCK_SIZE_M multiples)
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], num_experts
    )
    # Allocate intermediate buffers
    # GEMM1 output: [M, topk, N]
    intermediate_cache1 = torch.empty(
        (M, top_k, N), dtype=hidden_states.dtype, device=hidden_states.device
    )
    # After activation (SiLU+Mul): [M * topk, N // 2]
    intermediate_cache2 = torch.empty(
        (M * top_k, N // 2), dtype=hidden_states.dtype, device=hidden_states.device
    )
    # GEMM2 output: [M, topk, K]
    intermediate_cache3 = torch.empty(
        (M, top_k, K), dtype=hidden_states.dtype, device=hidden_states.device
    )
    # Final output: [M, K]. NOTE(review): zero-initialized (not torch.empty) —
    # presumably moe_sum requires/accumulates into an initialized buffer;
    # confirm against the moe_sum implementation.
    output = torch.zeros((M, K), dtype=hidden_states.dtype, device=hidden_states.device)
    # Step 2: GEMM1 — hidden_states @ W1 → intermediate_cache1
    # Router weights are NOT applied here; they are folded into GEMM2.
    invoke_fused_moe_triton_kernel(
        A=hidden_states,
        B=w1,
        C=intermediate_cache1,
        A_scale=None,
        B_scale=None,
        topk_weights=None,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=False,
        top_k=top_k,
        config=config,
        compute_type=compute_type,
    )
    # Step 3: Activation — SiLU(gate) * up, operating on the flat [M*topk, N] view
    _apply_silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
    # Step 4: GEMM2 — intermediate @ W2 → intermediate_cache3
    # Multiply router weights here
    invoke_fused_moe_triton_kernel(
        A=intermediate_cache2,
        B=w2,
        C=intermediate_cache3,
        A_scale=None,
        B_scale=None,
        topk_weights=topk_weights,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=True,
        top_k=1,  # After activation, each token-expert pair is independent
        config=config,
        compute_type=compute_type,
    )
    # Step 5: Reduce — sum over topK experts
    moe_sum(intermediate_cache3, output)
    return output