# src/flag_gems/patches/patch_vllm_all.py

import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

import flag_gems
from flag_gems.patches.patch_util import patch_module_method, patch_vllm_lib
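
# Overview: every `custom_*` function in this file is a drop-in replacement for a
# vLLM method or custom op, re-implemented on top of FlagGems kernels.
# `apply_gems_patches_to_vllm()` at the bottom wires them in: module methods
# (RMSNorm, RotaryEmbedding, SiluAndMul, the attention impls) are swapped via
# `patch_module_method`, and custom-op library entries (`_C`, `_moe_C`,
# `_C_cache_ops`, `_vllm_fa3_C`) are re-registered for the current device
# dispatch key via `patch_vllm_lib`.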

def custom_gems_rms_forward_cuda(self, x, residual=None):
    from flag_gems.modules.normalization import gems_rms_forward

    return gems_rms_forward(x, residual, self.weight, self.variance_epsilon)
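
# Rotary-embedding replacement below, mirroring vLLM's RotaryEmbedding.forward_cuda:
# positions are flattened, query/key are viewed as (num_tokens, num_heads, head_size),
# and only the first `rotary_dim` channels of each head are rotated; any remaining
# channels pass through unchanged. `rotary_interleaved` is the inverse of vLLM's
# `is_neox_style` flag, and the FlagGems kernel runs in place so the original
# query/key buffers are updated as vLLM expects.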

def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()
    num_tokens = positions.shape[0]

    query_shape = query.shape
    key_shape = key.shape
    query = query.view(num_tokens, -1, self.head_size)
    key = key.view(num_tokens, -1, self.head_size)

    query_rot = query[..., : self.rotary_dim]
    key_rot = key[..., : self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim :]
        key_pass = key[..., self.rotary_dim :]

    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

    q_embed, k_embed = gems_rope_forward(
        query_rot,
        key_rot,
        cos,
        sin,
        position_ids=positions,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,  # set inplace to True for vLLM compatibility
    )

    if self.rotary_dim < self.head_size:
        query = torch.cat((q_embed, query_pass), dim=-1).reshape(query_shape)
        key = torch.cat((k_embed, key_pass), dim=-1).reshape(key_shape)
    else:
        query = q_embed.reshape(query_shape)
        key = k_embed.reshape(key_shape)

    return query, key
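
# SiluAndMul replacement below: vLLM's SiluAndMul splits the last dimension in
# half and computes silu(x1) * x2. The split happens here and the elementwise
# fusion is delegated to `gems_silu_and_mul`, which is expected to follow the
# same convention.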

def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    from flag_gems.modules.activation import gems_silu_and_mul

    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return gems_silu_and_mul(x1, x2)


def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping.flatten(),
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    batch, num_head_q, head_dim_v = q_nope.shape
    seqlen_q = 1

    q = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = q.shape[-1]
    q = q.view(batch, seqlen_q, num_head_q, head_dim)

    # Add a head dim of 1
    kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

    block_table = attn_metadata.decode.block_table
    output = flash_mla(
        q,
        block_table,
        kv_c_and_k_pe_cache,
        None,
        PAGE_SIZE,
        batch,
        seqlen_q,
        attn_metadata.decode.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )

    o = self._v_up_proj_and_o_proj(output)
    return o
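
# FlashAttentionImpl.forward replacement below. It follows the same overall flow
# as vLLM's FlashAttention backend: new key/value tokens are first written into
# the paged KV cache with `reshape_and_cache_flash`, then `flash_attn_varlen_func`
# computes attention over the first `num_actual_tokens` queries directly into the
# provided `output` buffer. Local (iRoPE) attention metadata is used when
# available; FP8 KV caches, fused output quantization, and cascade attention are
# not supported and raise.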

def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  #: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    # TODO: Support FP8
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )
        # key_cache = key_cache.view(torch.float8_e4m3fn)
        # value_cache = value_cache.view(torch.float8_e4m3fn)
        # num_tokens, num_heads, head_size = query.shape
        # query, _ = ops.scaled_fp8_quant(
        #     query.reshape((num_tokens, num_heads * head_size)).contiguous(),
        #     layer._q_scale,
        # )
        # query = query.reshape((num_tokens, num_heads, head_size))

    # Compute attention and update output up to `num_actual_tokens`.
    # use_local_attn = self.use_irope and attn_metadata.local_attn_metadata is not None
    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )
    if not attn_metadata.use_cascade or use_local_attn:
        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
            scheduler_metadata = local_metadata.local_scheduler_metadata
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            scheduler_metadata=scheduler_metadata,
            fa_version=2,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            s_aux=None,
            num_splits=0,
            cp_world_size=1,
            cp_rank=0,
            cp_tot_seqused_k=None,
        )
        return output

    # TODO: Support cascade_attention.
    raise NotImplementedError("Cascade attention is not implemented in flag_gems.")
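
# Unlike `SiluAndMul.forward_cuda` above (which returns a new tensor), the wrapper
# below matches the out-parameter signature of the `_C.silu_and_mul` custom op:
# it splits the input along the last dimension and writes the fused result into
# `out` via `flag_gems.silu_and_mul_out`.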

def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    d = input.size(-1) // 2
    x, y = input.split(d, dim=-1)
    flag_gems.silu_and_mul_out(x, y, out)


def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    flag_gems.moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )

def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    from flag_gems.fused import grouped_topk

    return grouped_topk(
        scores=gating_output,
        n_group=n_group,
        topk_group=topk_group,
        topk=topk,
        renormalize=renormalize,
        routed_scaling_factor=routed_scaling_factor,
        bias=bias,
        scoring_func=scoring_func,
    )


def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    flag_gems.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
    )


def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    from flag_gems.fused import moe_sum

    moe_sum(input, output)


def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    return flag_gems.apply_repetition_penalties(
        logits, prompt_mask, output_mask, repetition_penalties
    )

def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    return flag_gems.get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_k_new=cu_seqlens_k_new,
        seqused_q=seqused_q,
        leftpad_k=leftpad_k,
        page_size=page_size,
        max_seqlen_k_new=max_seqlen_k_new,
        is_causal=is_causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        has_softcap=has_softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
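
# Per-token-group FP8 quantization below, registered against the
# `_C.per_token_group_fp8_quant` op, which hands in pre-allocated output buffers.
# Whether the scales buffer is column-major is inferred from its strides
# (stride(0) < stride(1)), and the quantized values and scales are copied back
# into `output_q` / `output_s`.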

def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    from flag_gems.ops import per_token_group_quant_fp8

    column_major_scales = output_s.stride(0) < output_s.stride(1)

    x_q, x_s = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=column_major_scales,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(x_q)
    output_s.copy_(x_s)


def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    return flag_gems.cutlass_scaled_mm(output, input, weight, scale_a, scale_b, bias)


def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    return flag_gems.concat_and_cache_mla(
        kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
    )

def custom_gems_flashattn_mla_forward_decode(
    self,
    q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata,  # FlashAttnMLAMetadata
    layer,  # AttentionLayer
) -> tuple[torch.Tensor, torch.Tensor | None]:
    from flag_gems import flash_attn_varlen_func

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if type(q) is tuple:
        q_nope, q_pe = q
    else:
        q_nope, q_pe = torch.split(
            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

    kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
    k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

    # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
    # kernel uses this to calculate grid dimensions. Ensure it's at least 1
    # to prevent invalid grid configuration during graph capture.
    max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

    attn_out = flash_attn_varlen_func(
        q=q_pe,
        k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
        v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
        q_v=q_nope,
        max_seqlen_q=max_seqlen_q,
        cu_seqlens_q=attn_metadata.decode.query_start_loc,
        max_seqlen_k=attn_metadata.decode.max_seq_len,
        seqused_k=attn_metadata.decode.seq_lens,
        block_table=attn_metadata.decode.block_table,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=self.need_to_return_lse_for_decode,
        fa_version=2,
        scheduler_metadata=attn_metadata.decode.scheduler_metadata,
        num_splits=0,
        cp_world_size=self.dcp_world_size,
        cp_rank=self.dcp_rank,
        cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
    )

    if self.need_to_return_lse_for_decode:
        o, lse = attn_out
        # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
        return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
    else:
        o = attn_out
        return o, None
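
# ViT attention patch below: `patch_vllm_vit_to_attn` wraps vLLM's
# `vit_xformers_attn_wrapper` so that, unless VIT_ATTN_BACKEND is set to
# "no-sdpa", per-sequence attention runs through the FlagGems flash-attention
# kernel. The per-image sequence lengths are first converted to cumulative
# offsets, e.g. seqlens [3, 5] -> cu_seqlens [0, 3, 8], and each slice is
# attended to independently before the outputs are concatenated back together.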

# use gems flash attention in vit attention
def patch_vllm_vit_to_attn(vitw):
    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        outputs = []
        for i in range(1, int(cu_seqlens.numel())):
            start = int(cu_seqlens[i - 1].item())
            end = int(cu_seqlens[i].item())
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        x = context_layer.transpose(0, 1).contiguous()
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper
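
# Entry point: apply all of the patches above. Module methods are replaced via
# `patch_module_method`, custom-op library entries are re-registered for the
# current FlagGems device dispatch key via `patch_vllm_lib`, and the ViT
# attention wrapper is patched when `vit_attn_wrappers` is importable.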

def apply_gems_patches_to_vllm(verbose=True):
    import vllm  # noqa: F401
    import vllm._custom_ops as ops  # noqa: F401

    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except (ModuleNotFoundError, ImportError):
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key

    module_patches = [
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    ]
    for cls, method_name, new_method in module_patches:
        patch_module_method(cls, method_name, new_method, verbose)

    lib_patches = [
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    ]
    for lib_name, fn_name, fn in lib_patches:
        patch_vllm_lib(lib_name, fn_name, fn, dispatch_key, verbose)

    if vitw is not None:
        patch_vllm_vit_to_attn(vitw)
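

# Minimal usage sketch (illustrative; assumes vLLM and FlagGems are installed on a
# supported device). Applying the patches before building the engine ensures the
# patched methods and ops are the ones vLLM picks up:
#
#     from flag_gems.patches.patch_vllm_all import apply_gems_patches_to_vllm
#
#     apply_gems_patches_to_vllm(verbose=True)
#     # ... then construct vLLM's LLM / engine as usual.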