Coverage for src/flag_gems/patches/patch_vllm_all.py: 14%
185 statements
coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

import flag_gems
from flag_gems.patches.patch_util import patch_module_method, patch_vllm_lib
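
# The `custom_gems_*` functions below are drop-in replacements for vLLM
# internals: the method-style ones are swapped onto vLLM layer and backend
# classes with patch_module_method, while the free functions are registered
# against vLLM's compiled op libraries (_C, _moe_C, ...) with patch_vllm_lib,
# so the corresponding kernels are served by FlagGems implementations.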


def custom_gems_rms_forward_cuda(self, x, residual=None):
    from flag_gems.modules.normalization import gems_rms_forward

    return gems_rms_forward(x, residual, self.weight, self.variance_epsilon)
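

# Replacement for RotaryEmbedding.forward_cuda. Only the first `rotary_dim`
# channels of each head are rotated; any remaining channels are passed
# through unchanged, and gems_rope_forward runs in place to match the
# semantics of vLLM's rotary kernel.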
def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()
    num_tokens = positions.shape[0]

    query_shape = query.shape
    key_shape = key.shape
    query = query.view(num_tokens, -1, self.head_size)
    key = key.view(num_tokens, -1, self.head_size)

    query_rot = query[..., : self.rotary_dim]
    key_rot = key[..., : self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim :]
        key_pass = key[..., self.rotary_dim :]

    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

    q_embed, k_embed = gems_rope_forward(
        query_rot,
        key_rot,
        cos,
        sin,
        position_ids=positions,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,  # set inplace to True for vLLM compatibility
    )

    if self.rotary_dim < self.head_size:
        query = torch.cat((q_embed, query_pass), dim=-1).reshape(query_shape)
        key = torch.cat((k_embed, key_pass), dim=-1).reshape(key_shape)
    else:
        query = q_embed.reshape(query_shape)
        key = k_embed.reshape(key_shape)

    return query, key


def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    from flag_gems.modules.activation import gems_silu_and_mul

    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return gems_silu_and_mul(x1, x2)
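

# Replacement for PagedAttention.write_to_paged_cache, forwarding to the
# FlagGems reshape_and_cache kernel with a flattened slot_mapping (one slot
# index per token).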
def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping.flatten(),
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
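

# Replacement for TritonMLAImpl._forward_decode: the "nope" and rope query
# parts are concatenated, the paged KV cache gets an explicit head dim of 1,
# and the FlagGems flash_mla kernel output is projected back through the
# layer's V-up / output projection.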
def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    batch, num_head_q, head_dim_v = q_nope.shape
    seqlen_q = 1

    q = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = q.shape[-1]
    q = q.view(batch, seqlen_q, num_head_q, head_dim)

    # Add a head dim of 1
    kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

    block_table = attn_metadata.decode.block_table
    output = flash_mla(
        q,
        block_table,
        kv_c_and_k_pe_cache,
        None,
        PAGE_SIZE,
        batch,
        seqlen_q,
        attn_metadata.decode.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )

    o = self._v_up_proj_and_o_proj(output)
    return o
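

# Replacement for FlashAttentionImpl.forward: new keys/values are written into
# the paged KV cache via reshape_and_cache_flash, then the FlagGems varlen
# flash attention kernel computes attention for the actual tokens. FP8 KV
# caches and cascade attention are not supported and raise NotImplementedError.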
def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  #: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    # TODO: Support FP8
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )
        # key_cache = key_cache.view(torch.float8_e4m3fn)
        # value_cache = value_cache.view(torch.float8_e4m3fn)
        # num_tokens, num_heads, head_size = query.shape
        # query, _ = ops.scaled_fp8_quant(
        #     query.reshape((num_tokens, num_heads * head_size)).contiguous(),
        #     layer._q_scale,
        # )
        # query = query.reshape((num_tokens, num_heads, head_size))

    # Compute attention and update output up to `num_actual_tokens`.
    # use_local_attn = self.use_irope and attn_metadata.local_attn_metadata is not None
    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )
    if not attn_metadata.use_cascade or use_local_attn:
        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
            scheduler_metadata = local_metadata.local_scheduler_metadata
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            scheduler_metadata=scheduler_metadata,
            fa_version=2,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            s_aux=None,
            num_splits=0,
            cp_world_size=1,
            cp_rank=0,
            cp_tot_seqused_k=None,
        )
        return output

    # TODO: Support cascade_attention.
    raise NotImplementedError("Cascade attention is not implemented in flag_gems.")
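

# The wrappers below mirror the signatures of vLLM's compiled custom ops
# (the _C, _moe_C, _vllm_fa3_C and _C_cache_ops libraries) and are registered
# for the current device's dispatch key in apply_gems_patches_to_vllm, so the
# ops resolve to FlagGems kernels instead of vLLM's built-in implementations.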
def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    d = input.size(-1) // 2
    x, y = input.split(d, dim=-1)
    flag_gems.silu_and_mul_out(x, y, out)


def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    flag_gems.moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    from flag_gems.fused import grouped_topk

    return grouped_topk(
        scores=gating_output,
        n_group=n_group,
        topk_group=topk_group,
        topk=topk,
        renormalize=renormalize,
        routed_scaling_factor=routed_scaling_factor,
        bias=bias,
        scoring_func=scoring_func,
    )


def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    flag_gems.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
    )


def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    from flag_gems.fused import moe_sum

    moe_sum(input, output)


def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    return flag_gems.apply_repetition_penalties(
        logits, prompt_mask, output_mask, repetition_penalties
    )


def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    return flag_gems.get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_k_new=cu_seqlens_k_new,
        seqused_q=seqused_q,
        leftpad_k=leftpad_k,
        page_size=page_size,
        max_seqlen_k_new=max_seqlen_k_new,
        is_causal=is_causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        has_softcap=has_softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
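

# The scale layout is inferred from output_s's strides: a row stride smaller
# than the column stride means the scales are stored column-major, and this is
# forwarded to the FlagGems kernel as column_major_scales.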
def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    from flag_gems.ops import per_token_group_quant_fp8

    column_major_scales = output_s.stride(0) < output_s.stride(1)

    x_q, x_s = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=column_major_scales,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(x_q)
    output_s.copy_(x_s)


def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    return flag_gems.cutlass_scaled_mm(output, input, weight, scale_a, scale_b, bias)


def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    return flag_gems.concat_and_cache_mla(
        kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
    )
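

# Replacement for FlashAttnMLAImpl._forward_decode: the query is split into
# its latent ("nope") and rope parts, the FlagGems varlen flash attention
# kernel runs with the rope parts as q/k, the latent cache as v, and q_nope
# passed as q_v; when decode context parallelism needs the log-sum-exp it is
# transposed from [H, B] to [B, H] before being returned.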
def custom_gems_flashattn_mla_forward_decode(
    self,
    q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata,  # FlashAttnMLAMetadata
    layer,  # AttentionLayer
) -> tuple[torch.Tensor, torch.Tensor | None]:
    from flag_gems import flash_attn_varlen_func

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if type(q) is tuple:
        q_nope, q_pe = q
    else:
        q_nope, q_pe = torch.split(
            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

    kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
    k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

    # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
    # kernel uses this to calculate grid dimensions. Ensure it's at least 1
    # to prevent invalid grid configuration during graph capture.
    max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

    attn_out = flash_attn_varlen_func(
        q=q_pe,
        k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
        v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
        q_v=q_nope,
        max_seqlen_q=max_seqlen_q,
        cu_seqlens_q=attn_metadata.decode.query_start_loc,
        max_seqlen_k=attn_metadata.decode.max_seq_len,
        seqused_k=attn_metadata.decode.seq_lens,
        block_table=attn_metadata.decode.block_table,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=self.need_to_return_lse_for_decode,
        fa_version=2,
        scheduler_metadata=attn_metadata.decode.scheduler_metadata,
        num_splits=0,
        cp_world_size=self.dcp_world_size,
        cp_rank=self.dcp_rank,
        cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
    )

    if self.need_to_return_lse_for_decode:
        o, lse = attn_out
        # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
        return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
    else:
        o = attn_out
        return o, None
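

# ViT attention: vLLM's vit_xformers_attn_wrapper is replaced by a wrapper
# that, unless VIT_ATTN_BACKEND is set to "no-sdpa", runs each sequence of the
# packed batch through the FlagGems flash_attention_forward kernel instead of
# the original xformers path.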
# use gems flash attention in vit attention
def patch_vllm_vit_to_attn(vitw):
    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        outputs = []
        for i in range(1, int(cu_seqlens.numel())):
            start = int(cu_seqlens[i - 1].item())
            end = int(cu_seqlens[i].item())
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        x = context_layer.transpose(0, 1).contiguous()
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper
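

# Entry point: imports the vLLM classes to patch, resolves the FlagGems
# dispatch key for the current device, and applies both patch tables (class
# methods via patch_module_method, compiled ops via patch_vllm_lib). The ViT
# attention patch is applied only when vit_attn_wrappers could be imported.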
def apply_gems_patches_to_vllm(verbose=True):
    import vllm  # noqa: F401
    import vllm._custom_ops as ops  # noqa: F401

    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except (ModuleNotFoundError, ImportError):
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key

    module_patches = [
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    ]
    for cls, method_name, new_method in module_patches:
        patch_module_method(cls, method_name, new_method, verbose)

    lib_patches = [
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    ]
    for lib_name, fn_name, fn in lib_patches:
        patch_vllm_lib(lib_name, fn_name, fn, dispatch_key, verbose)

    if vitw is not None:
        patch_vllm_vit_to_attn(vitw)
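

# Usage sketch (illustrative, not part of this module): the patches are
# typically applied right after importing vllm and before the engine is
# constructed, so every layer picks up the FlagGems kernels. The model name
# below is a placeholder.
#
#     import vllm
#     from flag_gems.patches.patch_vllm_all import apply_gems_patches_to_vllm
#
#     apply_gems_patches_to_vllm(verbose=True)
#     llm = vllm.LLM(model="<your-model>")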