Coverage for src/flag_gems/runtime/register.py: 70%
79 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import warnings
3from . import backend, common, error
4from .backend.device import DeviceDetector
class Register:
    """Register operator implementations on a library for the current device.

    Each entry of ``config`` is a sequence ``(op_name, func)`` or
    ``(op_name, func, condition_func)``.  When a third element is present it
    is called with no arguments and the entry is used only if the result is
    truthy.

    The instance works in one of two modes, chosen by ``user_include_ops``:

    * include mode (``user_include_ops`` is non-empty): only the entries
      selected by :meth:`extract_include_config` are registered.
    * exclude mode (otherwise): every enabled entry is registered except
      those whose function name is in ``user_exclude_ops`` or in the vendor's
      unused-op list, and those whose op name is in ``cpp_patched_ops``.

    Registration is performed from the constructor through :meth:`for_each`.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        """Filter ``config`` and register the surviving entries on ``lib``.

        Args:
            config: iterable of ``(op_name, func[, condition_func])`` entries.
            user_include_ops: names to register exclusively; a non-empty value
                selects include mode.
            user_exclude_ops: function names to skip in exclude mode.
            cpp_patched_ops: op names to skip in exclude mode.
            lib: object exposing ``impl(key, fn, dispatch_key)``.
            full_config_by_func: optional mapping of function name to a list
                of config entries, used as a fast path in include mode.
        """
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library
        # Some inference chips may not support the backward implementation of operators
        self.lib = lib

        # reg_key like 'CUDA'
        self.reg_key = self.device.dispatch_key
        self.all_ops = []
        self.all_keys = []

        # optional mapping func_name -> list of config entries
        self.full_config_by_func = full_config_by_func

        # Defined in both modes so the attribute always exists; it is only
        # applied as a filter in exclude mode (see config_filter).
        self.cpp_patched_ops = set(cpp_patched_ops or [])
        self.config = config

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config_filter()
        self.for_each()

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` when it has no name."""
        try:
            return func.__name__
        except AttributeError:
            return str(func)

    @staticmethod
    def _is_enabled(config_item):
        """Return True unless the entry's optional condition function is falsy."""
        return len(config_item) < 3 or bool(config_item[2]())

    def extract_include_config(self):
        """Build ``self.include_config`` from the names in ``self.include_ops``.

        With ``full_config_by_func`` available, each requested name is looked
        up directly in that mapping.  Otherwise ``self.config`` is scanned and
        an entry is kept when either its op name or its function name was
        requested.  Entries whose condition function returns a falsy value are
        dropped in both cases.  A warning is issued when nothing is selected.
        """
        if self.full_config_by_func:
            # Fast path: direct lookup by requested name, in request order.
            candidates = (
                item
                for name in self.include_ops
                for item in self.full_config_by_func.get(name, [])
            )
        else:
            # Fallback: scan the provided config, matching by op or func name.
            wanted = set(self.include_ops)
            candidates = (
                item
                for item in self.config
                if item[0] in wanted or self._func_name(item[1]) in wanted
            )

        self.include_config = [
            (item[0], item[1]) for item in candidates if self._is_enabled(item)
        ]

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    def config_filter(self):
        """Reduce ``self.config`` to the ``(op_name, func)`` pairs to register.

        An entry is kept when it is enabled, its function name is not in
        ``self.exclude_ops`` and its op name is not in ``self.cpp_patched_ops``.
        """
        # Built once so each entry costs a constant-time lookup.
        excluded = set(self.exclude_ops)
        self.config = [
            (item[0], item[1])
            for item in self.config
            if self._is_enabled(item)
            and self._func_name(item[1]) not in excluded
            and item[0] not in self.cpp_patched_ops
        ]

    def get_vendor_unused_op(self):
        """Return the backend's unused-op list for non-NVIDIA vendors, else ``[]``."""
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_curent_device_unused_op(self.device.vendor_name)
        return []

    def register_impl(self, key, fn):
        """Register ``fn`` for op ``key`` under this device's dispatch key.

        The op is added to ``all_ops`` / ``all_keys`` only after the library
        call returns, so a failed registration is never reported as done.

        Raises:
            ValueError: if no library instance was supplied.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        self.lib.impl(key, fn, self.reg_key)
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)

    def for_each(self):
        """Register every ``(key, func)`` pair currently in ``self.config``.

        Any exception raised while registering is handed to
        ``error.register_error``.
        """
        for key, func in self.config:
            try:
                self.register_impl(key, func)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return the function names recorded by ``register_impl``."""
        return self.all_ops

    def get_all_keys(self):
        """Return the op keys recorded by ``register_impl``."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the excluded function names (empty in include mode)."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name