Coverage for src/flag_gems/fused/fused_marlin_moe.py: 34%

326 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# SPDX-License-Identifier: Apache-2.0 

2""" 

3Fused Marlin MoE for FlagGems. 

4 

5Aligns with the interface of vLLM v0.20.0:

6 vllm/model_executor/layers/fused_moe/fused_marlin_moe.py :: fused_marlin_moe 

7 

8PHASE 2 (this file): bypass `fused_experts_impl`'s dequant-then-FP16-GEMM 

9shortcut and dispatch directly to the wna16 Triton kernel 

10(`fused_moe_kernel_gptq_awq`) for true fused-dequant W4A16/W8A16 GEMM. 

11 

12The local helper `_fused_marlin_moe_impl` mirrors `fused_experts_impl`'s 

13orchestration (chunk loop, moe_align, two GEMMs, activation, reduction) 

14but deletes the INT4/INT8 dequant branch and forwards `block_shape` so 

15the wna16 path is actually taken. 

16 

17MVP scope: 

18 - quant_type: GPTQ uint4b8 (INT4) and uint8b128 (INT8) 

19 - activation: SwiGLU / SiLU 

20 - act_order: NOT supported (g_idx / sort_indices must be None) 

21 - FP8 input: NOT supported 

22 - LoRA, clamp_limit, expert_map: NOT supported 

23""" 

24import functools 

25from typing import Any, Callable, Optional, Tuple 

26 

27import torch 

28import triton 

29import triton.language as tl 

30from torch.utils.weak import WeakTensorKeyDictionary 

31 

32from flag_gems.fused.fused_moe import ( 

33 MoEActivation, 

34 _get_config_dtype_str, 

35 _get_config_quant_dtype, 

36 apply_moe_activation, 

37 dispatch_fused_moe_kernel, 

38 moe_kernel_quantize_input, 

39 try_get_optimal_moe_config, 

40 write_zeros_to_output, 

41) 

42from flag_gems.fused.moe_align_block_size import moe_align_block_size 

43from flag_gems.fused.moe_sum import moe_sum 

44from flag_gems.fused.silu_and_mul import silu_and_mul_out 

45 

46# ---------------------------------------------------------------------------- 

47# quant_type_id constants — mirror a subset of vLLM scalar_types ids. 

48# ---------------------------------------------------------------------------- 

# GPTQ INT4 (weight stored as w + 8, dequant subtracts 8)
QUANT_TYPE_UINT4B8 = 0
# INT8 (weight stored as w + 128)
QUANT_TYPE_UINT8B128 = 1

# Membership sets consulted by `fused_marlin_moe`: an id found in
# `_QUANT_TYPE_INT4` selects the W4A16 flag, one found in `_QUANT_TYPE_INT8`
# selects the W8A16 flag, and any id outside `_SUPPORTED_QUANT_TYPES` is
# rejected with NotImplementedError.
_QUANT_TYPE_INT4 = {QUANT_TYPE_UINT4B8}
_QUANT_TYPE_INT8 = {QUANT_TYPE_UINT8B128}
_SUPPORTED_QUANT_TYPES = _QUANT_TYPE_INT4 | _QUANT_TYPE_INT8

57 

58 

59@functools.lru_cache(maxsize=1) 

60def _is_hopper() -> bool: 

61 return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 

62 

63 

64# ============================================================================ 

65# W4A16 (GPTQ uint4b8) fast path: tile-B + nibble-interleaved weight packing 

66# fed to a magic-number SIMD INT4->bf16/fp16 dequant + tl.dot kernel. This is 

67# the Hopper-gated short path taken by fused_marlin_moe for plain GPTQ uint4b8. 

68# ============================================================================ 

# Memoised results of `_pack_w_interleave`: weight tensor -> {block_size_k:
# packed int32 tensor} (filled by `_cached_pack_w`).
# NOTE(review): keys are held through WeakTensorKeyDictionary, so presumably an
# entry disappears once its source tensor is garbage collected; in-place edits
# of a cached weight are not detected.
_W_PACK_CACHE: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
# Memoised results of `_pack_scale_transpose`: scale tensor -> transposed
# contiguous tensor (filled by `_cached_pack_scale`).
_SCALE_PACK_CACHE: WeakTensorKeyDictionary = WeakTensorKeyDictionary()

71 

72 

73def _pack_w_interleave(w: torch.Tensor, block_size_k: int) -> torch.Tensor: 

74 assert w.dtype == torch.uint8 

75 assert w.ndim == 3 

76 assert ( 

77 block_size_k % 8 == 0 

78 ), f"BLOCK_SIZE_K={block_size_k} must be multiple of 8 (8 logical K per int32)" 

79 E, N_out, K_half = w.shape 

80 K = K_half * 2 

81 B = block_size_k // 8 

82 assert K % (8 * B) == 0, f"K={K} must be divisible by BLOCK_SIZE_K={block_size_k}" 

83 num_groups = K // (8 * B) 

84 

85 _NIBBLE_PERM = (0, 4, 1, 5, 2, 6, 3, 7) 

86 _BIT_SHIFTS = tuple(4 * p for p in _NIBBLE_PERM) 

87 shifts = torch.tensor(_BIT_SHIFTS, dtype=torch.int32, device=w.device) 

88 out = torch.empty(E, K // 8, N_out, dtype=torch.int32, device=w.device) 

89 

90 for e in range(E): 

91 we = w[e] # (N_out, K//2) uint8 

92 low = (we & 0xF).to(torch.uint8) 

93 high = ((we >> 4) & 0xF).to(torch.uint8) 

94 unpacked = torch.stack([low, high], dim=-1).reshape(N_out, K) 

95 tiled = unpacked.reshape(N_out, num_groups, 8, B).transpose(-1, -2) 

96 # (N_out, num_groups, B, 8) 

97 packed = (tiled.to(torch.int32) << shifts).sum(dim=-1, dtype=torch.int32) 

98 # (N_out, num_groups, B) -> (N_out, K//8) 

99 packed = packed.reshape(N_out, K // 8) 

100 out[e].copy_(packed.transpose(0, 1)) 

101 return out # (E, K//8, N_out) 

102 

103 

104def _pack_scale_transpose(s: torch.Tensor) -> torch.Tensor: 

105 assert s.ndim == 3 

106 return s.transpose(-2, -1).contiguous() 

107 

108 

def _cached_pack_w(w: torch.Tensor, block_size_k: int, cached: bool) -> torch.Tensor:
    """Return `_pack_w_interleave(w, block_size_k)`, optionally memoised.

    With ``cached`` true the result is stored in `_W_PACK_CACHE` under the
    weight tensor, then under ``block_size_k``, and reused on later calls.
    With ``cached`` false the cache is neither read nor written.
    """
    if not cached:
        return _pack_w_interleave(w, block_size_k)

    # First level: one dict of packed variants per weight tensor.
    variants = _W_PACK_CACHE.get(w)
    if variants is None:
        variants = {}
        _W_PACK_CACHE[w] = variants

    # Second level: one packed tensor per block_size_k.
    hit = variants.get(block_size_k)
    if hit is not None:
        return hit
    fresh = _pack_w_interleave(w, block_size_k)
    variants[block_size_k] = fresh
    return fresh

121 

122 

def _cached_pack_scale(s: torch.Tensor, cached: bool) -> torch.Tensor:
    """Return `_pack_scale_transpose(s)`, optionally memoised.

    With ``cached`` true the result is stored in `_SCALE_PACK_CACHE` under the
    scale tensor and reused on later calls; otherwise the cache is bypassed.
    """
    if not cached:
        return _pack_scale_transpose(s)

    hit = _SCALE_PACK_CACHE.get(s)
    if hit is not None:
        return hit
    fresh = _pack_scale_transpose(s)
    _SCALE_PACK_CACHE[s] = fresh
    return fresh

131 

132 

def w4a16_pack(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    *,
    cached: bool = True,
    pack_strategy: str = "interleave",
    block_size_k: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Pack both expert weight tensors (and their scales) for the W4A16 kernel.

    Args:
        w1, w2: weights handed to `_cached_pack_w`.
        w1_scale, w2_scale: optional scales handed to `_cached_pack_scale`;
            a ``None`` scale is returned as ``None``.
        cached: reuse / populate the module-level pack caches.
        pack_strategy: only ``"interleave"`` is implemented.
        block_size_k: logical K tile size used for the weight packing.

    Returns:
        ``(w1_packed, w2_packed, w1_scale_packed, w2_scale_packed)``.

    Raises:
        NotImplementedError: for any other ``pack_strategy``.
    """
    if pack_strategy != "interleave":
        raise NotImplementedError(
            f"pack_strategy={pack_strategy!r} not supported (only 'interleave')"
        )

    packed_weights = [
        _cached_pack_w(weight, block_size_k, cached=cached) for weight in (w1, w2)
    ]
    packed_scales = [
        None if scale is None else _cached_pack_scale(scale, cached=cached)
        for scale in (w1_scale, w2_scale)
    ]
    return packed_weights[0], packed_weights[1], packed_scales[0], packed_scales[1]

156 

157 

@triton.jit
def _dequant_int4_fp16(b, scales):
    # Dequantise one packed int32 word per element of `b` into eight fp16
    # values, each equal to (nibble - 8) * scale, using inline PTX.
    #
    # Operands (see `constraints`): $0..$7 are the eight 16-bit outputs, $8 is
    # the 32-bit packed word from `b`, $9 is the 16-bit scale from `scales`.
    #
    # How the PTX works, per its own inline notes:
    #   * A nibble is OR-ed into the low mantissa bits of fp16 1024 (0x6400),
    #     giving 1024 + n for nibbles at bits 0-3 / 16-19; subtracting 1032
    #     leaves n - 8.
    #   * Nibbles at bits 4-7 / 20-23 land 4 bits higher, giving 1024 + 16 * n;
    #     the fma with 1/16 and -72 turns that into 64 + n - 72 = n - 8.
    #   * `shr r1, r0, 8` repeats both steps for the upper two nibbles of each
    #     16-bit half.
    # Outputs are produced in the order of bit offsets 0, 16, 4, 20, 8, 24,
    # 12, 28, which is the nibble permutation written by `_pack_w_interleave`,
    # so x1..x8 come out in logical K order within the word.
    x1, x2, x3, x4, x5, x6, x7, x8 = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b32 r0, r1, r2, r3, r4, r5, r6, r8, r9, r10, r11, r12;
            .reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;
            .reg .b16 s;
            mov.u32 r0, $8;
            shr.u32 r1, r0, 8;
            lop3.b32 r2, r0, 983055, 1677747200, 234; // (r0 & 0x000F000F) | 0x64006400
            lop3.b32 r3, r0, 15728880, 1677747200, 234; // (r0 & 0x00F000F0) | 0x64006400
            lop3.b32 r4, r1, 983055, 1677747200, 234;
            lop3.b32 r5, r1, 15728880, 1677747200, 234;
            mov.u32 r6, 1678271496; // 0x64086408 = (1032,1032)
            mov.u32 r8, 738208768; // 0x2C002C00 = (1/16,1/16)
            mov.u32 r9, -729754496; // 0xD480D480 = (-72,-72)
            sub.f16x2 r10, r2, r6;
            sub.f16x2 r12, r4, r6;
            fma.rn.f16x2 r11, r3, r8, r9;
            fma.rn.f16x2 r4, r5, r8, r9;
            mov.b32 {h0, h1}, r10;
            mov.b32 {h2, h3}, r11;
            mov.b32 {h4, h5}, r12;
            mov.b32 {h6, h7}, r4;
            mov.b16 s, $9;
            mul.f16 h0, h0, s;
            mul.f16 h1, h1, s;
            mul.f16 h2, h2, s;
            mul.f16 h3, h3, s;
            mul.f16 h4, h4, s;
            mul.f16 h5, h5, s;
            mul.f16 h6, h6, s;
            mul.f16 h7, h7, s;
            mov.b16 $0, h0;
            mov.b16 $1, h1;
            mov.b16 $2, h2;
            mov.b16 $3, h3;
            mov.b16 $4, h4;
            mov.b16 $5, h5;
            mov.b16 $6, h6;
            mov.b16 $7, h7;
        }
        """,
        # Eight 16-bit outputs, then a 32-bit input and a 16-bit input.
        constraints="=h,=h,=h,=h,=h,=h,=h,=h,r,h",
        args=[b, scales],
        dtype=(tl.float16,) * 8,
        is_pure=True,
        pack=1,
    )
    return x1, x2, x3, x4, x5, x6, x7, x8

209 

210 

@triton.jit
def _dequant_int4_bf16(b, scales):
    # bf16 counterpart of `_dequant_int4_fp16`: turns one packed int32 word
    # per element of `b` into eight bf16 values equal to (nibble - 8) * scale.
    #
    # Operands (see `constraints`): $0..$7 are the eight 16-bit outputs, $8 is
    # the 32-bit packed word from `b`, $9 is the 16-bit scale from `scales`.
    #
    # Per the inline PTX notes: the word is shifted right by 0/4/8/12 so every
    # nibble pair lands in bits 0-3 / 16-19, each is OR-ed into bf16 128
    # (0x4300) to form 128 + n, and 136 (0x4308) is subtracted to leave n - 8.
    # Outputs follow bit offsets 0, 16, 4, 20, 8, 24, 12, 28, i.e. the nibble
    # permutation written by `_pack_w_interleave`, so x1..x8 are in logical K
    # order within the word.
    # NOTE(review): `sub.rn.bf16x2` / `mul.rn.bf16` presumably require a recent
    # PTX target (the module header calls this path Hopper-gated) — confirm.
    x1, x2, x3, x4, x5, x6, x7, x8 = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b32 r0, r1, r2, r3, q0, q1, q2, q3, s0, s1, s2, s3, magic;
            .reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;
            .reg .b16 s;
            mov.u32 r0, $8;
            shr.u32 r1, r0, 4; // high nibble of bytes 0,2 -> bits 0-3
            shr.u32 r2, r0, 8; // low nibble of bytes 1,3 -> bits 0-3
            shr.u32 r3, r0, 12; // high nibble of bytes 1,3 -> bits 0-3
            // (x & 0x000F000F) | 0x43004300 -> bf16x2 of (128+nibble, 128+nibble)
            lop3.b32 q0, r0, 983055, 1124090624, 234;
            lop3.b32 q1, r1, 983055, 1124090624, 234;
            lop3.b32 q2, r2, 983055, 1124090624, 234;
            lop3.b32 q3, r3, 983055, 1124090624, 234;
            mov.u32 magic, 1124614920; // 0x43084308 = (136,136)
            sub.rn.bf16x2 s0, q0, magic;
            sub.rn.bf16x2 s1, q1, magic;
            sub.rn.bf16x2 s2, q2, magic;
            sub.rn.bf16x2 s3, q3, magic;
            mov.b32 {h0, h1}, s0; // (n0-8, n4-8)
            mov.b32 {h2, h3}, s1; // (n1-8, n5-8)
            mov.b32 {h4, h5}, s2; // (n2-8, n6-8)
            mov.b32 {h6, h7}, s3; // (n3-8, n7-8)
            mov.b16 s, $9;
            mul.rn.bf16 h0, h0, s;
            mul.rn.bf16 h1, h1, s;
            mul.rn.bf16 h2, h2, s;
            mul.rn.bf16 h3, h3, s;
            mul.rn.bf16 h4, h4, s;
            mul.rn.bf16 h5, h5, s;
            mul.rn.bf16 h6, h6, s;
            mul.rn.bf16 h7, h7, s;
            mov.b16 $0, h0;
            mov.b16 $1, h1;
            mov.b16 $2, h2;
            mov.b16 $3, h3;
            mov.b16 $4, h4;
            mov.b16 $5, h5;
            mov.b16 $6, h6;
            mov.b16 $7, h7;
        }
        """,
        # Eight 16-bit outputs, then a 32-bit input and a 16-bit input.
        constraints="=h,=h,=h,=h,=h,=h,=h,=h,r,h",
        args=[b, scales],
        dtype=(tl.bfloat16,) * 8,
        is_pure=True,
        pack=1,
    )
    return x1, x2, x3, x4, x5, x6, x7, x8

263 

264 

@triton.jit
def _stack_along_dim0(a, b, X: tl.constexpr, Y: tl.constexpr):
    # Block-concatenate two (X, Y) tiles along dim 0 into one (2X, Y) tile:
    # join -> (X, Y, 2), permute -> (2, X, Y), reshape -> (2X, Y).
    joined = tl.join(a, b)
    return tl.reshape(tl.permute(joined, (2, 0, 1)), (2 * X, Y))

270 

271 

@triton.jit
def _stack_8(bs, K_PACK: tl.constexpr, N: tl.constexpr):
    # Concatenate the eight (K_PACK, N) tiles bs[0..7] along dim 0, in index
    # order, as a balanced tree of pairwise stacks: 8 -> 4 -> 2 -> 1 tiles.
    pair_a = _stack_along_dim0(bs[0], bs[1], K_PACK, N)  # (2*K_PACK, N)
    pair_b = _stack_along_dim0(bs[2], bs[3], K_PACK, N)
    lower_half = _stack_along_dim0(pair_a, pair_b, 2 * K_PACK, N)  # (4*K_PACK, N)
    pair_c = _stack_along_dim0(bs[4], bs[5], K_PACK, N)
    pair_d = _stack_along_dim0(bs[6], bs[7], K_PACK, N)
    upper_half = _stack_along_dim0(pair_c, pair_d, 2 * K_PACK, N)
    return _stack_along_dim0(lower_half, upper_half, 4 * K_PACK, N)  # (8*K_PACK, N)

281 

282 

# Fused-dequant W4A16 MoE GEMM.
#
# Each program computes one (BLOCK_SIZE_M tokens) x (BLOCK_SIZE_N outputs) tile
# of C for the expert that owns that token block. B holds INT4 weights packed
# eight per int32; they are expanded to `compute_type` in registers by
# `_dequant_int4_fp16` / `_dequant_int4_bf16` right before the tl.dot calls.
#
# Only BLOCK_SIZE_N, GROUP_SIZE_M, num_warps and num_stages are autotuned; the
# remaining constexprs are supplied by the launcher.
# NOTE(review): the autotune key is (N, K) only, so a config tuned for one
# BLOCK_SIZE_M / SWAP_AB value is presumably reused for others — confirm this
# is intended.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=2
        ),
    ],
    key=["N", "K"],
)
@triton.jit
def _w4a16_moe_gemm_kernel(
    a_ptr,  # activations; row index is token_slot // top_k
    b_ptr,  # packed int32 weights, indexed (expert, packed-K, N) via strides
    c_ptr,  # output; row index is the token slot itself
    b_scale_ptr,  # per-group scales, indexed (expert, group, N) via strides
    topk_weights_ptr,  # routing weights, read only when MUL_ROUTED_WEIGHT
    sorted_token_ids_ptr,  # token slots grouped per expert block
    expert_ids_ptr,  # one expert id per BLOCK_SIZE_M block; -1 = no expert
    num_tokens_post_padded_ptr,  # scalar: number of used entries in sorted ids
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsg,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,  # token tile (MMA M-dim, or N-dim if SWAP_AB)
    BLOCK_SIZE_N: tl.constexpr,  # weight tile (MMA N-dim, or M-dim if SWAP_AB)
    BLOCK_SIZE_K: tl.constexpr,  # logical-K tile (must match packing)
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,  # = quant group_size (e.g. 128)
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    SWAP_AB: tl.constexpr,
):
    # Number of packed int32 rows per logical-K tile (8 nibbles per int32).
    BLOCK_SIZE_K_PACK: tl.constexpr = BLOCK_SIZE_K // 8

    # Map the 1-D program id to (pid_m, pid_n); GROUP_SIZE_M consecutive row
    # tiles are walked together for each column tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row tiles that start beyond the padded token count have nothing to do.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    # Token slots handled by this tile; slots >= num_valid_tokens are padding
    # and are masked out of every load/store below.
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Block has no expert assigned: emit zeros for its outputs and stop.
        if SWAP_AB:
            # Same zero fill as write_zeros_to_output, for the (N, M) layout.
            offs_cn0 = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs0 = (
                c_ptr + stride_cm * offs_token[None, :] + stride_cn * offs_cn0[:, None]
            )
            c_mask0 = token_mask[None, :] & (offs_cn0[:, None] < N)
            tl.store(
                c_ptrs0,
                tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=compute_type),
                mask=c_mask0,
            )
        else:
            write_zeros_to_output(
                c_ptr,
                stride_cm,
                stride_cn,
                pid_n,
                N,
                offs_token,
                token_mask,
                BLOCK_SIZE_M,
                BLOCK_SIZE_N,
                compute_type,
            )
        return

    # Column offsets wrap modulo N so the unmasked weight/scale loads stay in
    # range; the final store uses unwrapped offsets with an explicit mask.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_ak_pack = tl.arange(0, BLOCK_SIZE_K_PACK)
    offs_bk = tl.arange(0, BLOCK_SIZE_K_PACK)

    # SWAP_AB computes the transposed product (weights on the left), so the
    # weight tile is (N, K_PACK) and the accumulator is (N, M).
    if SWAP_AB:
        a_base = a_ptr + (offs_token[None, :] // top_k * stride_am)
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_bn[:, None] * stride_bn
            + offs_bk[None, :] * stride_bk
        )
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
    else:
        a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_bk[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    scale_base = b_scale_ptr + off_experts * stride_bse + offs_bn * stride_bsn

    # NOTE(review): neither the weight load nor the activation load is masked
    # along K, so this loop relies on K being a multiple of BLOCK_SIZE_K
    # (`_pack_w_interleave` asserts that for the weights).
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        b_packed = tl.load(b_ptrs)
        # One scale row per K tile: the quant group containing the tile start.
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale = tl.load(scale_base + scale_idx * stride_bsg)
        scale_bc = scale[:, None] if SWAP_AB else scale[None, :]

        # Expand each int32 into 8 scaled values -> tuple of 8 tiles.
        if compute_type == tl.float16:
            bs = _dequant_int4_fp16(b_packed, scale_bc)
        else:
            bs = _dequant_int4_bf16(b_packed, scale_bc)

        k_logical_base = k * BLOCK_SIZE_K
        # Tile j of the dequant output pairs with logical K offsets
        # [j * K_PACK, (j + 1) * K_PACK) inside this K tile.
        for j in tl.static_range(8):
            k_off = k_logical_base + j * BLOCK_SIZE_K_PACK
            if SWAP_AB:
                a_j_ptrs = a_base + (k_off + offs_ak_pack[:, None]) * stride_ak
                a_j = tl.load(
                    a_j_ptrs, mask=token_mask[None, :], other=0.0
                )  # (K_PACK, M)
                accumulator = tl.dot(bs[j], a_j, acc=accumulator)  # (N, M)
            else:
                a_j_ptrs = a_base + (k_off + offs_ak_pack[None, :]) * stride_ak
                a_j = tl.load(
                    a_j_ptrs, mask=token_mask[:, None], other=0.0
                )  # (M, K_PACK)
                accumulator = tl.dot(a_j, bs[j], acc=accumulator)  # (M, N)

        # Advance to the packed rows of the next K tile.
        b_ptrs += BLOCK_SIZE_K_PACK * stride_bk

    # Optionally scale every token's row by its routing weight, indexed by the
    # flat token slot.
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
        accumulator = accumulator * (
            moe_weight[None, :] if SWAP_AB else moe_weight[:, None]
        )

    accumulator = accumulator.to(compute_type)

    # Store, masking padded token slots and columns past N.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if SWAP_AB:
        c_ptrs = c_ptr + stride_cm * offs_token[None, :] + stride_cn * offs_cn[:, None]
        c_mask = token_mask[None, :] & (offs_cn[:, None] < N)
    else:
        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
        c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

457 

458 

def _invoke_w4a16_moe_gemm(
    A: torch.Tensor,  # activations: (M, K) for GEMM1, (M*top_k, K) for GEMM2
    B: torch.Tensor,  # packed weights: (E, K//8, N) int32
    C: torch.Tensor,  # output: (M, top_k, N) or a (M*top_k, N) view
    B_scale: torch.Tensor,  # scales: (E, K/gs, N) fp16/bf16
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    *,
    mul_routed_weight: bool,
    top_k: int,
    block_m: int,
    block_size_k: int,
    group_size: int,
    compute_type,  # tl.float16 or tl.bfloat16
    swap_ab: bool = False,
):
    """Launch `_w4a16_moe_gemm_kernel` for one MoE GEMM.

    Derives the problem sizes and strides from the tensors, trims the number
    of token slots when there are fewer activation rows than one token tile,
    and starts a 1-D grid with one program per (token tile, output tile).
    The result is written into ``C``; nothing is returned.
    """
    rows_a = A.shape[0]
    K = A.shape[1]
    N = B.shape[2]

    # Upper bound on token slots; with very few rows the padded id list is
    # longer than anything the kernel can actually touch.
    EM = sorted_token_ids.shape[0]
    if rows_a < block_m:
        EM = min(EM, rows_a * top_k * block_m)

    # For a 3-D C the token stride is that of the top_k axis (dim 1).
    token_dim = 1 if C.ndim == 3 else 0
    stride_cm = C.stride(token_dim)
    stride_cn = C.stride(token_dim + 1)

    def grid(META):
        token_tiles = triton.cdiv(EM, META["BLOCK_SIZE_M"])
        output_tiles = triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (token_tiles * output_tiles,)

    num_valid_tokens = rows_a * top_k
    stride_am, stride_ak = A.stride(0), A.stride(1)
    stride_be, stride_bk, stride_bn = B.stride(0), B.stride(1), B.stride(2)
    stride_bse, stride_bsg, stride_bsn = (
        B_scale.stride(0),
        B_scale.stride(1),
        B_scale.stride(2),
    )

    _w4a16_moe_gemm_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        num_valid_tokens,
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bse,
        stride_bsg,
        stride_bsn,
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_K=block_size_k,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        SWAP_AB=swap_ab,
    )

526 

527 

def fused_moe_w4a16_gptq(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    *,
    activation: str = "silu",
    group_size: int = 128,
    apply_router_weight_on_input: bool = False,
    inplace: bool = False,
    swap_ab: bool = True,
) -> torch.Tensor:
    """Run a full W4A16 MoE layer (GEMM1 -> SiLU-and-mul -> GEMM2 -> sum).

    With ``M`` tokens, hidden size ``K``, ``E`` experts and intermediate size
    ``I = w1.size(1) // 2``, the asserts below require:

    Args:
        hidden_states: contiguous (M, K) fp16 or bf16 activations.
        w1: uint8 (E, 2 * I, K // 2), two INT4 values per byte.
        w2: uint8 (E, K, I // 2), two INT4 values per byte.
        w1_scale: (E, 2 * I, K // group_size), same dtype as hidden_states.
        w2_scale: (E, K, I // group_size), same dtype as hidden_states.
        topk_weights: routing weights, same shape as ``topk_ids``.
        topk_ids: selected expert ids; ``topk_ids.size(1)`` is top_k.
        activation: only ``"silu"`` is accepted.
        group_size: quantisation group size; also used as the kernel K tile.
        apply_router_weight_on_input: apply routing weights in GEMM1 instead
            of GEMM2.
        inplace: write the result into ``hidden_states`` instead of a new
            tensor.
        swap_ab: forwarded to the kernel as SWAP_AB; also lowers the
            average-tokens cutoff for picking the smallest token tile.

    Returns:
        Tensor shaped like ``hidden_states`` (``hidden_states`` itself when
        ``inplace``).

    Raises:
        AssertionError: if any of the shape/dtype/layout checks fails.
    """
    assert activation == "silu"
    assert hidden_states.dtype in (torch.float16, torch.bfloat16)
    assert hidden_states.is_contiguous()
    assert w1.dtype == torch.uint8 and w2.dtype == torch.uint8
    assert w1.stride(-1) == 1 and w2.stride(-1) == 1

    M = hidden_states.size(0)
    K = hidden_states.size(1)
    E = w1.size(0)
    intermediate_size = w1.size(1) // 2
    top_k_num = topk_ids.size(1)

    assert w1.shape == (E, 2 * intermediate_size, K // 2)
    assert w2.shape == (E, K, intermediate_size // 2)
    assert K % group_size == 0
    assert intermediate_size % group_size == 0
    assert w1_scale.shape == (E, 2 * intermediate_size, K // group_size)
    assert w2_scale.shape == (E, K, intermediate_size // group_size)
    assert w1_scale.dtype == hidden_states.dtype
    assert w2_scale.dtype == hidden_states.dtype
    assert topk_weights.shape == topk_ids.shape

    # One K tile per quant group, so every K tile needs exactly one scale row.
    # NOTE(review): the kernel's tl.dot then has inner dimension
    # group_size // 8; presumably small group sizes are unsupported — confirm.
    block_size_k = group_size
    # Compute_type for the kernel.
    if hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    else:
        compute_type = tl.bfloat16

    # Repack weights/scales into the kernel layout; results are cached per
    # source tensor, so this cost is paid once per weight.
    w1_packed, w2_packed, w1_scale_packed, w2_scale_packed = w4a16_pack(
        w1,
        w2,
        w1_scale,
        w2_scale,
        block_size_k=block_size_k,
        cached=True,
    )

    # cache1 (GEMM1 output) and cache3 (GEMM2 output) alias one buffer: cache1
    # is fully consumed by the activation before GEMM2 writes cache3.
    cache13_size = M * top_k_num * max(2 * intermediate_size, K)
    cache13 = torch.empty(
        cache13_size, device=hidden_states.device, dtype=hidden_states.dtype
    )
    intermediate_cache1 = cache13[: M * top_k_num * 2 * intermediate_size].view(
        M * top_k_num, 2 * intermediate_size
    )
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # Token-tile heuristic: 16 / 32 / 64 rows based on the average number of
    # routed tokens per expert.
    avg_tokens = max(M * top_k_num // max(E, 1), 1)
    cutoff = 8 if swap_ab else 16
    block_m = 16 if avg_tokens <= cutoff else (32 if avg_tokens <= 64 else 64)
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids=topk_ids,
        block_size=block_m,
        num_experts=E,
        expert_map=None,
    )

    # GEMM1: hidden_states @ w1 -> cache1, one row per (token, expert) slot.
    _invoke_w4a16_moe_gemm(
        A=hidden_states,
        B=w1_packed,
        C=intermediate_cache1,
        B_scale=w1_scale_packed,
        topk_weights=topk_weights if apply_router_weight_on_input else None,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=apply_router_weight_on_input,
        top_k=top_k_num,
        block_m=block_m,
        block_size_k=block_size_k,
        group_size=group_size,
        compute_type=compute_type,
        swap_ab=swap_ab,
    )

    # First half of the columns is the gate, second half the up projection.
    gate = intermediate_cache1[:, :intermediate_size]
    up = intermediate_cache1[:, intermediate_size:]
    silu_and_mul_out(gate, up, intermediate_cache2)

    # GEMM2: activation @ w2 -> cache3. Rows of A are already per-slot, hence
    # top_k=1; routing weights are applied here unless GEMM1 already did.
    _invoke_w4a16_moe_gemm(
        A=intermediate_cache2,
        B=w2_packed,
        C=intermediate_cache3,
        B_scale=w2_scale_packed,
        topk_weights=topk_weights if not apply_router_weight_on_input else None,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=not apply_router_weight_on_input,
        top_k=1,
        block_m=block_m,
        block_size_k=block_size_k,
        group_size=group_size,
        compute_type=compute_type,
        swap_ab=swap_ab,
    )

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)
    # Presumably reduces the (M, top_k, K) cache over top_k into (M, K) —
    # confirm against moe_sum.
    moe_sum(intermediate_cache3, out_hidden_states)

    return out_hidden_states

652 

653 

654# ---------------------------------------------------------------------------- 

655# Phase-2 impl: copy of fused_experts_impl but with the dequant shortcut 

656# removed so the wna16 Triton kernel is actually invoked for W4A16/W8A16. 

657# ---------------------------------------------------------------------------- 

def _fused_marlin_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Like fused_experts_impl, but:
      - drops all paths irrelevant to W4A16/W8A16 (no fp8, int8_w8a8, mxfp).
      - REMOVES the `w = w.to(fp16) * scale.unsqueeze(-1)` dequant shortcut.
      - forwards block_shape so the wna16 kernel uses the right group_size.

    Tokens are processed in chunks of at most 16 * 1024. Per chunk:
    moe_align_block_size -> GEMM1 (w1) -> activation -> GEMM2 (w2) -> moe_sum
    into the matching slice of the output.

    Args:
        hidden_states: contiguous 2-D fp32/fp16/bf16 activations.
        w1, w2: quantised expert weights with unit stride on the last dim;
            ``w1.size(2)`` must equal hidden_size // 2 for INT4, else
            hidden_size.
        topk_weights, topk_ids: routing weights / expert ids, same shape.
        inplace: write the result into ``hidden_states``.
        activation: only ``"silu"`` is accepted.
        apply_router_weight_on_input: apply routing weights in GEMM1 rather
            than GEMM2.
        use_int8_w8a16, use_int4_w4a16: at least one must be true.
        per_channel_quant, block_shape, w*_scale, w*_zp, w*_bias: forwarded
            to the quantise/dispatch helpers.
        global_num_experts: expert count for alignment; -1 means ``w1.size(0)``.
        expert_map: forwarded to moe_align_block_size; when given, the GEMM2
            output buffer is zeroed before GEMM2 runs.

    Returns:
        Tensor shaped like ``hidden_states`` (``hidden_states`` itself when
        ``inplace``).

    Raises:
        AssertionError: on unsupported activation or failed shape/layout
            checks.
        ValueError: on an unsupported ``hidden_states`` dtype.
    """
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"
    assert (
        use_int4_w4a16 or use_int8_w8a16
    ), "_fused_marlin_moe_impl expects a quantized path"

    activation_enum = MoEActivation.from_str(activation)

    # Packed-aware shape check.
    #   W4A16 (pack_factor=2): w1.size(2) == K // 2
    #   W8A16 (pack_factor=1): w1.size(2) == K
    expected_packed_k = (
        hidden_states.size(1) // 2 if use_int4_w4a16 else hidden_states.size(1)
    )
    assert w1.size(2) == expected_packed_k, (
        f"w1 packed K mismatch: hidden_size={hidden_states.size(1)}, "
        f"use_int4_w4a16={use_int4_w4a16}, expected w1.size(2)={expected_packed_k}, "
        f"got {w1.size(2)}"
    )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    # N is w1's output width (gate + up), K is w2's output width.
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # Scratch buffers are sized for one chunk, not the whole batch.
    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=False,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=None,
        dtype=hidden_states.dtype,
    )
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        ocp_mx_scheme=None,
    )

    # Kernel-config lookup with everything but the token count pre-bound, so
    # it can be re-queried for a smaller final chunk.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )
    config = get_config_func(M)
    config["SPLIT_K"] = 1

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    # ★ Phase-2 KEY DIFFERENCE: the W4A16/W8A16 dequant shortcut that lived
    #   here in `fused_experts_impl` is intentionally REMOVED. The wna16
    #   Triton kernel will consume INT4 weights + scale directly.

    # The "+ 1" covers a partial last chunk; when num_tokens is an exact
    # multiple of CHUNK_SIZE the extra iteration is empty and exits via break.
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        # A short trailing chunk: shrink the scratch views and refresh the
        # kernel config for the smaller token count.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)
            config["SPLIT_K"] = 1

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        # Activation quantization is a no-op for W4A16/W8A16 (no input quant).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        # Use the routed-path (skip the SPARSITY_FACTOR shortcut, which is
        # explicitly disabled for quantized + block_shape configs anyway).
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids,
            config["BLOCK_SIZE_M"],
            global_num_experts,
            expert_map,
        )

        # ----- GEMM 1: hidden @ w1 (fused dequant on B inside the kernel) -----
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # ----- Activation: SwiGLU = silu(gate) * up -----
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        # With an expert map some slots may never be written by GEMM2
        # (presumably those routed to non-local experts), so clear the buffer
        # that moe_sum will read.
        if expert_map is not None:
            intermediate_cache3.zero_()

        # ----- GEMM 2: act @ w2 (fused dequant on B inside the kernel) -----
        # Routing weights go in here unless GEMM1 already applied them; rows
        # of A are per (token, expert) slot, hence top_k = 1.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # ----- Reduce: sum topk-weighted expert outputs back per token -----
        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

895 

896 

897# ---------------------------------------------------------------------------- 

898# Public entry point: vLLM-aligned wrapper. 

899# ---------------------------------------------------------------------------- 

def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Validate a vLLM-style fused MoE request and dispatch it to a kernel path.

    The call is first checked against the MVP feature set.  It is then
    routed to ``fused_moe_w4a16_gptq`` when every fast-path precondition
    holds, and to ``_fused_marlin_moe_impl`` (with
    ``block_shape=[0, group_size]``) otherwise.

    ``activation_func``, ``moe_sum``, ``workspace``, ``intermediate_cache13``,
    ``intermediate_cache2`` and ``is_k_full`` are part of the signature but
    are never read by this function.

    Args:
        quant_type_id: Must be a member of ``_SUPPORTED_QUANT_TYPES``; its
            membership in ``_QUANT_TYPE_INT4`` / ``_QUANT_TYPE_INT8`` selects
            the ``use_int4_w4a16`` / ``use_int8_w8a16`` flags.
        activation: ``None`` (treated as ``"silu"``), a string, or an object
            whose ``value`` or ``name`` attribute is a string.  The resolved
            name is lower-cased and must equal ``"silu"``.
        bias1, bias2: Forwarded as ``w1_bias`` / ``w2_bias`` on the generic
            path; must be ``None`` for the fast path.
        w1_zeros, w2_zeros: Forwarded as ``w1_zp`` / ``w2_zp`` on the generic
            path; must be ``None`` for the fast path.
        g_idx1, g_idx2, sort_indices1, sort_indices2, input_dtype,
        clamp_limit, input_global_scale1, input_global_scale2,
        global_scale1, global_scale2: Must all be ``None``.
        output: Optional tensor that receives a copy of the result.  Cannot
            be combined with ``inplace=True``.
        group_size: Forwarded to the fast path, or used as the second entry
            of ``block_shape`` on the generic path.

    Returns:
        ``output`` (after ``output.copy_(result)``) when it was supplied,
        otherwise the tensor returned by the selected implementation.

    Raises:
        NotImplementedError: For an unsupported ``quant_type_id``, any of the
            rejected optional features, or an activation other than SiLU.
        ValueError: If ``inplace`` is true and ``output`` is also given.
    """
    # ---- MVP guardrails --------------------------------------------------
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )

    # Optional features that must stay unset; checked in this order so the
    # first offending one determines the error message.
    rejected_features = (
        (
            any(t is not None for t in (g_idx1, g_idx2)),
            "act_order (g_idx) not yet supported in MVP",
        ),
        (
            any(t is not None for t in (sort_indices1, sort_indices2)),
            "act_order (sort_indices) not yet supported in MVP",
        ),
        (
            input_dtype is not None,
            "FP8 / INT8 input quantization not supported",
        ),
        (
            clamp_limit is not None,
            "clamp_limit (GLM-4 swiglu) not supported",
        ),
        (
            any(t is not None for t in (input_global_scale1, input_global_scale2)),
            "input_global_scale not supported in MVP",
        ),
        (
            any(t is not None for t in (global_scale1, global_scale2)),
            "global_scale not supported in MVP",
        ),
    )
    for requested, reason in rejected_features:
        if requested:
            raise NotImplementedError(reason)

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    # ---- Activation name normalisation -----------------------------------
    # Accept enum-like objects (string ``value`` or ``name``) as well as
    # plain strings; a plain string takes precedence over the attributes.
    activation_str = "silu"
    if activation is not None:
        attr_values = (
            getattr(activation, attr, None) for attr in ("value", "name")
        )
        label = next((v for v in attr_values if isinstance(v, str)), None)
        if isinstance(activation, str):
            label = activation
        if label is not None:
            activation_str = label.lower()
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    # ---- Kernel selection ------------------------------------------------
    # The magic-trick kernel's bf16 dequant uses sub.bf16x2/mul.bf16 PTX,
    # which require sm_90+; on pre-Hopper fall back to the generic wna16
    # kernel.  The remaining terms restrict the fast path to plain INT4
    # weights without bias, zero points or expert remapping.
    take_w4a16_fast_path = (
        _is_hopper()
        and use_int4_w4a16
        and hidden_states.dtype in (torch.float16, torch.bfloat16)
        and w1.dtype == torch.uint8
        and w2.dtype == torch.uint8
        and all(t is None for t in (bias1, bias2, w1_zeros, w2_zeros, expert_map))
        and (global_num_experts == -1 or global_num_experts == w1.size(0))
        and group_size >= 128
        and w1_scale.dtype == hidden_states.dtype
        and w2_scale.dtype == hidden_states.dtype
    )

    if take_w4a16_fast_path:
        result = fused_moe_w4a16_gptq(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation_str,
            group_size=group_size,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
        )
    else:
        result = _fused_marlin_moe_impl(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation_str,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zeros,
            w2_zp=w2_zeros,
            w1_bias=bias1,
            w2_bias=bias2,
            # Critical for Phase 2: block_shape=[0, group_size] makes the
            # wna16 Triton kernel use the per-group scales correctly.
            block_shape=[0, group_size],
        )

    # ---- Deliver the result ----------------------------------------------
    if output is None:
        return result
    output.copy_(result)
    return output

1041 

1042 

# Names exported by a star-import of this module.
__all__ = [
    "fused_marlin_moe",
    "QUANT_TYPE_UINT4B8",
    "QUANT_TYPE_UINT8B128",
]