Coverage for src/flag_gems/patches/patch_util.py: 64%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
from collections.abc import Callable

import torch
4def _try_import_vllm_extension(module_name):
5 try:
6 __import__(module_name)
7 return True
8 except ImportError:
9 return False
12def _is_op_registered(lib_name, op_name):
13 try:
14 lib = getattr(torch.ops, lib_name)
15 return hasattr(lib, op_name)
16 except Exception:
17 return False
def _ensure_vllm_library_exists(lib_name, ops_to_check=None):
    """Try to load the vllm extension backing *lib_name*.

    Returns True when the extension module imports and — if *ops_to_check*
    is given — at least ONE of the listed ops is registered. Returns False
    for unknown library names or failed imports.
    """
    module_map = {
        "_C": "vllm._C",
        "_moe_C": "vllm._moe_C",
        "_vllm_fa3_C": "vllm.vllm_flash_attn._vllm_fa3_C",
        "_C_cache_ops": "vllm._C_cache_ops",
    }

    module_name = module_map.get(lib_name)
    if module_name is None:
        return False
    if not _try_import_vllm_extension(module_name):
        return False
    if not ops_to_check:
        # Imported and nothing specific to verify.
        return True
    # NOTE(review): "any op present" (not "all") is the original semantics;
    # a partially-registered build is still treated as loaded.
    return any(_is_op_registered(lib_name, op_name) for op_name in ops_to_check)
# Probe ops per vllm extension library: _ensure_vllm_library_exists treats a
# library as usable as soon as ANY of the listed ops is registered.
_LIB_OPS = {
    "_C": [
        "silu_and_mul",
        "cutlass_scaled_mm",
        "per_token_group_fp8_quant",
        "apply_repetition_penalties_",
    ],
    "_moe_C": ["topk_softmax", "moe_align_block_size", "grouped_topk", "moe_sum"],
    "_vllm_fa3_C": ["get_scheduler_metadata"],
    "_C_cache_ops": ["concat_and_cache_mla"],
}
# Fallback torch.library schemas, defined only when the corresponding vllm
# extension fails to load (e.g. vllm built without custom ops), so that
# torch.ops.<lib>.<op> lookups and IMPL registrations below still succeed.
# NOTE(review): these presumably mirror vllm's native op schemas — keep in
# sync with the installed vllm version.
_OP_SIGNATURES = {
    "_moe_C": {
        "topk_softmax": "(Tensor(a!) topk_weights, Tensor(b!) topk_indices, "
        "Tensor(c!) token_expert_indices, Tensor gating_output) -> ()",
        "moe_align_block_size": "(Tensor topk_ids, int num_experts, "
        "int block_size, Tensor(a!) sorted_token_ids, Tensor(b!) experts_ids, "
        "Tensor(c!) num_tokens_post_pad) -> ()",
        "grouped_topk": "(Tensor gating_output, int n_group, int topk_group, "
        "int topk, bool renormalize, float routed_scaling_factor, Tensor? bias, "
        "int scoring_func=0) -> (Tensor, Tensor, Tensor)",
        "moe_sum": "(Tensor input, Tensor(a!) output) -> ()",
    },
    "_C": {
        "silu_and_mul": "(Tensor(a!) out, Tensor input) -> ()",
        "cutlass_scaled_mm": "(Tensor(a!) out, Tensor input, Tensor weight, "
        "Tensor scale_a, Tensor scale_b, Tensor? bias=None) -> ()",
        "per_token_group_fp8_quant": "(Tensor input, Tensor(a!) output_q, "
        "Tensor(b!) output_s, int group_size, float eps, float fp8_min, "
        "float fp8_max, bool scale_ue8m0=False) -> ()",
        "apply_repetition_penalties_": "(Tensor(a!) logits, Tensor prompt_mask, "
        "Tensor output_mask, Tensor repetition_penalties) -> Tensor",
    },
    "_vllm_fa3_C": {
        "get_scheduler_metadata": "(int batch_size, int max_seqlen_q, int max_seqlen_k, "
        "int num_heads, int num_heads_k, int headdim, int headdim_v, "
        "ScalarType qkv_dtype, Tensor seqused_k, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
        "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, "
        "Tensor? leftpad_k=None, int? page_size=None, "
        "int max_seqlen_k_new=0, bool is_causal=False, "
        "int window_size_left=-1, int window_size_right=-1, "
        "bool has_softcap=False, int num_splits=0, "
        "bool? pack_gqa=None, int sm_margin=0) -> Tensor",
    },
    "_C_cache_ops": {
        "concat_and_cache_mla": "(Tensor kv_c, Tensor k_pe, Tensor(a!) kv_cache, "
        "Tensor slot_mapping, str kv_cache_dtype, Tensor scale) -> ()",
    },
}
def _define_op_if_not_exists(lib_name, op_name, signature):
    """Define the schema for <lib_name>::<op_name> unless already registered.

    Definition failures are reported as a warning and swallowed (best-effort).
    """
    if _is_op_registered(lib_name, op_name):
        return
    try:
        torch.library.define(f"{lib_name}::{op_name}", signature)
    except Exception as e:
        print(f"Warning: Failed to define {lib_name}::{op_name}: {e}")
_libs_loaded = {}
for _lib_name, _ops in _LIB_OPS.items():
    _loaded = _ensure_vllm_library_exists(_lib_name, _ops)
    _libs_loaded[_lib_name] = _loaded
    # When the extension is absent (seen with vllm builds that ship without
    # custom ops, e.g. tsingmicro vllm), register stub schemas so the op
    # namespaces below can still be implemented against.
    if not _loaded:
        for _op_name, _signature in _OP_SIGNATURES.get(_lib_name, {}).items():
            _define_op_if_not_exists(_lib_name, _op_name, _signature)
# One IMPL-kind library handle per vllm namespace; patch_vllm_lib registers
# FLAGGEMS implementations through these.
libs = {
    name: torch.library.Library(name, "IMPL")
    for name in ("_C", "_moe_C", "_vllm_fa3_C", "_C_cache_ops")
}

# Named aliases kept for direct use by importers of this module.
vllm_C_lib = libs["_C"]
vllm_moe_C_lib = libs["_moe_C"]
vllm_fa3_C_lib = libs["_vllm_fa3_C"]
vllm_C_cache_ops_lib = libs["_C_cache_ops"]
def patch_module_method(cls, method_name: str, new_method: Callable, verbose=True):
    """Replace ``cls.method_name`` with *new_method* and return the old one.

    Args:
        cls: Class (or module-like object) to patch.
        method_name: Attribute name to overwrite.
        new_method: Replacement callable installed under *method_name*.
        verbose: When True, print a one-line notice of the patch.

    Returns:
        The previous attribute value, or None if it did not exist —
        callers can use it to restore the original later.
    """
    # Fix: the builtin ``callable`` is not a valid type annotation
    # (PEP 484); use collections.abc.Callable instead.
    old_method = getattr(cls, method_name, None)
    setattr(cls, method_name, new_method)
    if verbose:
        print(
            f"Patched {cls.__name__}.{method_name} with FLAGGEMS {new_method.__name__}"
        )
    return old_method
def patch_vllm_lib(lib_name, fn_name, fn, key, verbose=True):
    """Register *fn* as the *key* dispatch implementation of
    torch.ops.<lib_name>.<fn_name>.

    Raises:
        ValueError: if *lib_name* is not one of the known vllm namespaces.
    """
    if lib_name not in libs:
        raise ValueError(f"Library {lib_name} is not recognized.")

    libs[lib_name].impl(fn_name, fn, key)

    if verbose:
        print(f"Patched torch.ops.{lib_name}.{fn_name} with FLAGGEMS {fn.__name__}")