Coverage for src/flag_gems/runtime/register.py: 70%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import warnings 

2 

3from . import backend, common, error 

4from .backend.device import DeviceDetector 

5 

6 

class Register:
    """Register operator implementations into a torch library for one device.

    Each entry of ``config`` is a tuple ``(op_name, func)`` or
    ``(op_name, func, condition_func)``. When ``condition_func`` is present it
    is called with no arguments and the entry is skipped if it returns a falsy
    value. Surviving entries are registered via
    ``lib.impl(op_name, func, dispatch_key)``, where ``dispatch_key`` is the
    detected device's ``dispatch_key``.

    Two modes:

    * include mode (``user_include_ops`` non-empty): only requested entries
      are registered. With ``full_config_by_func`` the requested names are
      used as keys of that mapping; otherwise an entry matches when its
      function name or its op name is requested. ``cpp_patched_ops`` and the
      vendor unused-op list are not applied in this mode.
    * exclude mode (default): every enabled entry is registered except those
      whose function name is in ``user_exclude_ops`` or the vendor's
      unused-op list, and those whose op name is in ``cpp_patched_ops``.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library.
        # Some inference chips may not support the backward implementation of
        # operators.
        self.lib = lib

        # Dispatch key used for every registration, e.g. 'CUDA'.
        self.reg_key = self.device.dispatch_key
        # Function names / op keys of successful registrations, in order.
        self.all_ops = []
        self.all_keys = []

        # Optional mapping func_name -> list of config entries.
        self.full_config_by_func = full_config_by_func
        # Defined in both modes so the attribute always exists; only
        # config_filter (exclude mode) consults it.
        self.cpp_patched_ops = set(cpp_patched_ops or [])
        self.config = config

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config_filter()
        self.for_each()

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` when it has none.

        Callables such as ``functools.partial`` objects have no ``__name__``.
        """
        if hasattr(func, "__name__"):
            return func.__name__
        return str(func)

    @staticmethod
    def _enabled(item):
        """Return True unless the entry carries a condition that is falsy."""
        return len(item) < 3 or bool(item[2]())

    def extract_include_config(self):
        """Build ``self.include_config`` as ``(op_name, func)`` pairs.

        Uses ``full_config_by_func`` when available (fast path), otherwise
        scans ``self.config`` matching on function name or op name. Entries
        whose condition function returns a falsy value are dropped. Emits a
        warning when nothing matches.
        """
        self.include_config = []

        if self.full_config_by_func:
            # dict.fromkeys drops repeated names while keeping order, so the
            # same kernels are not registered twice.
            for name in dict.fromkeys(self.include_ops):
                for config_item in self.full_config_by_func.get(name, []):
                    if self._enabled(config_item):
                        self.include_config.append((config_item[0], config_item[1]))
        else:
            # Fallback: scan provided config and match by func name or op name.
            wanted = set(self.include_ops)
            for config_item in self.config:
                op_name, func = config_item[0], config_item[1]
                if self._func_name(func) not in wanted and op_name not in wanted:
                    continue
                if self._enabled(config_item):
                    self.include_config.append((op_name, func))

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    def config_filter(self):
        """Reduce ``self.config`` to enabled, non-excluded ``(op_name, func)`` pairs.

        An entry is dropped when its condition is falsy, its function name is
        in ``self.exclude_ops``, or its op name is in ``self.cpp_patched_ops``.
        """
        excluded = set(self.exclude_ops)
        self.config = [
            (item[0], item[1])
            for item in self.config
            if self._enabled(item)
            and self._func_name(item[1]) not in excluded
            and item[0] not in self.cpp_patched_ops
        ]

    def get_vendor_unused_op(self):
        """Return the ops the current (non-NVIDIA) vendor backend marks unused.

        NVIDIA devices always get an empty list.
        """
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_curent_device_unused_op(self.device.vendor_name)
        return []

    def register_impl(self, key, fn):
        """Register ``fn`` for op ``key`` under this device's dispatch key.

        The op is recorded in ``all_ops``/``all_keys`` only after
        ``lib.impl`` returns, so failed registrations are not reported as
        registered.

        Raises:
            ValueError: if no library instance was provided.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        self.lib.impl(key, fn, self.reg_key)
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)

    def for_each(self):
        """Register every ``(key, func)`` pair in ``self.config``.

        Exceptions raised while registering an op are handed to
        ``error.register_error`` instead of propagating directly from here.
        """
        for key, func in self.config:
            try:
                self.register_impl(key, func)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return function names of successfully registered ops."""
        return self.all_ops

    def get_all_keys(self):
        """Return op keys of successfully registered ops."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the excluded op names (empty in include mode)."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name