# src/flag_gems/patches/patch_util.py

import torch


def _try_import_vllm_extension(module_name):
    """Return True if the given vLLM extension module can be imported."""
    try:
        __import__(module_name)
        return True
    except ImportError:
        return False


def _is_op_registered(lib_name, op_name):
    """Return True if torch.ops.<lib_name>.<op_name> is already registered."""
    try:
        lib = getattr(torch.ops, lib_name)
        return hasattr(lib, op_name)
    except Exception:
        return False


def _ensure_vllm_library_exists(lib_name, ops_to_check=None):
    """Import the vLLM extension backing ``lib_name``.

    Returns True if the extension imports and, when ``ops_to_check`` is
    given, at least one of those ops is registered; otherwise False.
    """
    module_map = {
        "_C": "vllm._C",
        "_moe_C": "vllm._moe_C",
        "_vllm_fa3_C": "vllm.vllm_flash_attn._vllm_fa3_C",
        "_C_cache_ops": "vllm._C_cache_ops",
    }

    module_name = module_map.get(lib_name)
    if module_name:
        imported = _try_import_vllm_extension(module_name)
        if imported:
            if ops_to_check:
                for op_name in ops_to_check:
                    if _is_op_registered(lib_name, op_name):
                        return True
            else:
                return True

    return False
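
# Illustrative check (not executed at import): on a full CUDA build of vLLM
# this is expected to return True; on builds without the MoE extension, False.
#
#     _ensure_vllm_library_exists("_moe_C", ["topk_softmax"])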

_LIB_OPS = {
    "_C": [
        "silu_and_mul",
        "cutlass_scaled_mm",
        "per_token_group_fp8_quant",
        "apply_repetition_penalties_",
    ],
    "_moe_C": ["topk_softmax", "moe_align_block_size", "grouped_topk", "moe_sum"],
    "_vllm_fa3_C": ["get_scheduler_metadata"],
    "_C_cache_ops": ["concat_and_cache_mla"],
}

# Schemas used to define placeholder ops when a vLLM build ships without them.
_OP_SIGNATURES = {
    "_moe_C": {
        "topk_softmax": "(Tensor(a!) topk_weights, Tensor(b!) topk_indices, "
        "Tensor(c!) token_expert_indices, Tensor gating_output) -> ()",
        "moe_align_block_size": "(Tensor topk_ids, int num_experts, "
        "int block_size, Tensor(a!) sorted_token_ids, Tensor(b!) experts_ids, "
        "Tensor(c!) num_tokens_post_pad) -> ()",
        "grouped_topk": "(Tensor gating_output, int n_group, int topk_group, "
        "int topk, bool renormalize, float routed_scaling_factor, Tensor? bias, "
        "int scoring_func=0) -> (Tensor, Tensor, Tensor)",
        "moe_sum": "(Tensor input, Tensor(a!) output) -> ()",
    },
    "_C": {
        "silu_and_mul": "(Tensor(a!) out, Tensor input) -> ()",
        "cutlass_scaled_mm": "(Tensor(a!) out, Tensor input, Tensor weight, "
        "Tensor scale_a, Tensor scale_b, Tensor? bias=None) -> ()",
        "per_token_group_fp8_quant": "(Tensor input, Tensor(a!) output_q, "
        "Tensor(b!) output_s, int group_size, float eps, float fp8_min, "
        "float fp8_max, bool scale_ue8m0=False) -> ()",
        "apply_repetition_penalties_": "(Tensor(a!) logits, Tensor prompt_mask, "
        "Tensor output_mask, Tensor repetition_penalties) -> Tensor",
    },
    "_vllm_fa3_C": {
        "get_scheduler_metadata": "(int batch_size, int max_seqlen_q, int max_seqlen_k, "
        "int num_heads, int num_heads_k, int headdim, int headdim_v, "
        "ScalarType qkv_dtype, Tensor seqused_k, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
        "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, "
        "Tensor? leftpad_k=None, int? page_size=None, "
        "int max_seqlen_k_new=0, bool is_causal=False, "
        "int window_size_left=-1, int window_size_right=-1, "
        "bool has_softcap=False, int num_splits=0, "
        "bool? pack_gqa=None, int sm_margin=0) -> Tensor",
    },
    "_C_cache_ops": {
        "concat_and_cache_mla": "(Tensor kv_c, Tensor k_pe, Tensor(a!) kv_cache, "
        "Tensor slot_mapping, str kv_cache_dtype, Tensor scale) -> ()",
    },
}


def _define_op_if_not_exists(lib_name, op_name, signature):
    """Define a schema-only op under torch.ops.<lib_name> if it is missing."""
    if not _is_op_registered(lib_name, op_name):
        try:
            torch.library.define(f"{lib_name}::{op_name}", signature)
        except Exception as e:
            print(f"Warning: Failed to define {lib_name}::{op_name}: {e}")
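
# Note: torch.library.define registers only the schema; a placeholder op is
# then resolvable under torch.ops but has no kernel until one is attached via
# torch.library.Library(...).impl below. A minimal sketch with a hypothetical
# op name (nothing in this module defines "demo_op"):
#
#     torch.library.define("_C::demo_op", "(Tensor x) -> Tensor")
#     torch.ops._C.demo_op  # resolvable, but calling it before an impl is
#                           # registered raises a dispatch error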

_libs_loaded = {}
for lib_name, ops in _LIB_OPS.items():
    loaded = _ensure_vllm_library_exists(lib_name, ops)
    _libs_loaded[lib_name] = loaded

    # Missing ops appear only when vLLM was built without these custom
    # kernels (e.g., the tsingmicro vLLM build): define schema-only
    # placeholders so the impls registered below have something to attach to.
    if not loaded and lib_name in _OP_SIGNATURES:
        for op_name, signature in _OP_SIGNATURES[lib_name].items():
            _define_op_if_not_exists(lib_name, op_name, signature)

vllm_C_lib = torch.library.Library("_C", "IMPL")
vllm_moe_C_lib = torch.library.Library("_moe_C", "IMPL")
vllm_fa3_C_lib = torch.library.Library("_vllm_fa3_C", "IMPL")
vllm_C_cache_ops_lib = torch.library.Library("_C_cache_ops", "IMPL")

libs = {
    "_C": vllm_C_lib,
    "_moe_C": vllm_moe_C_lib,
    "_vllm_fa3_C": vllm_fa3_C_lib,
    "_C_cache_ops": vllm_C_cache_ops_lib,
}


def patch_module_method(cls, method_name: str, new_method: callable, verbose=True):
    """Replace ``cls.method_name`` with ``new_method`` and return the original."""
    old_method = getattr(cls, method_name, None)
    setattr(cls, method_name, new_method)
    if verbose:
        print(
            f"Patched {cls.__name__}.{method_name} with FLAGGEMS {new_method.__name__}"
        )
    return old_method
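
# Usage sketch (hypothetical class and replacement, for illustration only):
#
#     class Runner:
#         def forward(self):
#             return "original"
#
#     def gems_forward(self):
#         return "patched"
#
#     original = patch_module_method(Runner, "forward", gems_forward)
#     assert Runner().forward() == "patched"
#     patch_module_method(Runner, "forward", original, verbose=False)  # restore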


def patch_vllm_lib(lib_name, fn_name, fn, key, verbose=True):
    """Register ``fn`` as the ``key`` dispatch-key impl of torch.ops.<lib_name>.<fn_name>."""
    if lib_name not in libs:
        raise ValueError(f"Library {lib_name} is not recognized.")

    lib = libs[lib_name]
    lib.impl(fn_name, fn, key)

    if verbose:
        print(f"Patched torch.ops.{lib_name}.{fn_name} with FLAGGEMS {fn.__name__}")
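
# Usage sketch (hypothetical replacement kernel; the real FlagGems
# implementations are registered elsewhere in the package):
#
#     def gems_moe_sum(input, output):
#         torch.sum(input, dim=1, out=output)
#
#     patch_vllm_lib("_moe_C", "moe_sum", gems_moe_sum, "CUDA")
#     # torch.ops._moe_C.moe_sum now dispatches to gems_moe_sum on CUDA tensors.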