Coverage for src/flag_gems/runtime/backend/__init__.py: 82%
217 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import ast
2import functools
3import importlib
4import inspect
5import os
6import sys
7from pathlib import Path
9from ..common import vendors
10from . import backend_utils
# --- Module-level lazy caches ---------------------------------------------
# These are populated on demand by the setter/getter functions below and by
# import_vendor_extra_lib(); None/False means "not resolved yet".
vendor_module = None  # the imported `_<vendor>` backend package
device_name = None  # vendor device name from vendor_info (presumably e.g. "cuda" — see get_vendor_info)
torch_device_object = None  # the `torch.<device_name>` module (gen_torch_device_object)
torch_device_fn_device = None  # `torch.backends.<device_name>` module, or None for some vendors
tl_extra_backend_module = None  # `triton.language.extra.<name>.libdevice` module
ops_module = None  # vendor-specific common ops module (`_<vendor>.ops`)
fused_module = None  # vendor-specific fused ops module (`_<vendor>.fused`)
heuristic_config_module = None  # vendor heuristics-config module
vendor_extra_lib_imported = False  # one-shot guard for import_vendor_extra_lib()
device_fn_cache = {}  # cache dict; not written in this file — TODO confirm external users
customized_ops = None  # cached [(name, fn), ...] of vendor ops + fused ops
class BackendArchEvent:
    """Singleton describing the vendor GPU architecture specialization.

    On first construction it resolves the current architecture (via the
    vendor module's ARCH_MAP and either the ``ARCH`` env var or the torch
    device properties), then loads the per-arch module plus its autotune
    and heuristics configs.  Later constructions return the same,
    already-initialized instance.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        # Classic singleton: create the one instance lazily.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        # Guard so repeated constructions do not re-run initialization.
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like
            # FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    def get_functions_from_module(self, module):
        """Return [(name, function), ...] for *module*, or [] when module is falsy."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return the arch's HEURISTICS_CONFIGS, or None when it defines none."""
        heuristic_module = None
        try:
            # Raises AttributeError when the arch module was never loaded.
            heuristic_module = self.arch_module
        except Exception:  # noqa E722
            # Fallback: import heuristics_config_utils straight from the
            # arch directory, restoring sys.path afterwards.
            sys.path.insert(0, str(self.current_arch_path))
            heuristic_module = importlib.import_module("heuristics_config_utils")
            sys.path.remove(str(self.current_arch_path))
        if hasattr(heuristic_module, "HEURISTICS_CONFIGS"):
            return heuristic_module.HEURISTICS_CONFIGS
        return None

    def get_autotune_configs(self):
        """Load the tune config stored under the current arch directory."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Map the detected compute-capability digit to an arch name.

        Resolution order: the ``ARCH`` env var (e.g. "sm_90" -> "9"),
        then ``torch_device_object.get_device_properties(device).major``.
        Sets ``self.has_arch`` as a side effect; returns the mapped arch
        name on success, otherwise None/False.
        """
        if not hasattr(vendor_module, "ARCH_MAP"):
            return
        arch_map = vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # First digit of the last "_"-separated token; empty stays empty.
        arch_string_num = arch_string.split("_")[-1][0] if arch_string else arch_string
        if not arch_string_num:
            try:
                if not torch_device_object.is_available():
                    return False
                props = torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
        else:
            self.has_arch = True
            return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Scan the vendor package directory for arch subdirectories.

        Returns {arch_name: absolute_path}, excluding "ops"/"fused" and
        any underscore-prefixed directory.
        """
        path = path or vendor_module.__path__[0]
        excluded = ("ops", "fused")
        path = Path(path)
        path = path.parent if path.is_file() else path
        archs = {}
        for p in path.iterdir():
            # BUGFIX: use Path.name instead of splitting on "/" so the
            # scan also works with Windows path separators.
            name = p.name
            if p.is_dir() and name not in excluded and not name.startswith("_"):
                archs.update({name: str(p)})
        return archs

    def get_supported_archs(self):
        """List the arch names discovered for this vendor."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>, temporarily prepending its parent dir to sys.path."""
        path_dir = os.path.dirname(self.current_arch_path)
        sys.path.insert(0, str(path_dir))
        current_arch_module = importlib.import_module(self.arch)
        sys.path.remove(str(path_dir))
        return current_arch_module

    def get_arch_ops(self):
        """Collect (name, fn) pairs from the arch's ops module.

        BUGFIX: previously ``{arch}.ops`` was imported (and sys.path
        appended) unconditionally *before* the try/except, which defeated
        the fallback, let ImportError escape, and leaked a sys.path
        entry.  Now the cached arch_module is tried first, with a direct
        import only as the fallback.
        """
        arch_specialized_ops = []
        modules = []
        try:
            ops_module = self.arch_module.ops
            modules.append(ops_module)
        except Exception:
            try:
                sys.path.append(self.current_arch_path)
                ops_module = importlib.import_module(f"{self.arch}.ops")
                modules.append(ops_module)
            except Exception as err_msg:
                self.error_msgs.append(err_msg)

        for mod in modules:
            arch_specialized_ops.extend(self.get_functions_from_module(mod))

        return arch_specialized_ops
def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's `ops` and `fused` modules into the module caches.

    Runs at most once per process (guarded by vendor_extra_lib_imported).
    A missing vendor module is expected and only logged; any other import
    failure is re-raised as RuntimeError.
    """
    global vendor_extra_lib_imported
    if vendor_extra_lib_imported is True:
        return
    global ops_module, fused_module
    try:
        ops_module = importlib.import_module(f"_{vendor_name}.ops")
    except ModuleNotFoundError:
        # BUGFIX: the adjacent f-string fragments previously joined as
        # "found inthe {vendor} ..."; a trailing space restores the message.
        print(
            f"[Note] No specialized common operators were found in "
            f"the {vendor_name} implementation, and general common operators are used by default."
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Import vendor extra lib failed: {e}") from e

    try:
        fused_module = importlib.import_module(f"_{vendor_name}.fused")
    except ModuleNotFoundError:
        print(
            f"[Note] No specialized fused operators were found in "
            f"the {vendor_name} implementation, and general fused operators are used by default."
        )
    except Exception as e:
        raise RuntimeError(f"Import vendor extra lib failed: {e}") from e
    vendor_extra_lib_imported = True
def get_codegen_result(code, result_key):
    """Execute generated Python *code* in this module's globals and return
    the value bound to *result_key*.

    NOTE(security): this uses exec(); callers must only pass trusted,
    internally generated code, never external input.
    """
    parsed_ast = ast.parse(code)
    compiled_code = compile(parsed_ast, filename="<ast>", mode="exec")
    # A try/except that immediately did `raise e` was removed: it added no
    # handling and only obscured the original traceback.
    exec(compiled_code, globals())
    return globals()[result_key]
@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code.

    Results are LRU-cached per (tensor, attr_name) pair; arguments must
    therefore be hashable.
    """
    global device_name
    if not device_name:
        # Lazily resolve the device name the first time we are called.
        device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")
def set_tl_extra_backend_module(vendor_name=None):
    """Import `triton.language.extra.<name>.libdevice` for this vendor and
    cache it in the module-level `tl_extra_backend_module`."""
    global device_name, tl_extra_backend_module
    info = get_vendor_info(vendor_name)
    if not device_name:
        device_name = info.device_name
    # Some vendors override the extra-module name; fall back to the device name.
    extra_name = info.triton_extra_name or device_name
    module_str = f"triton.language.extra.{extra_name}.libdevice"
    tl_extra_backend_module = importlib.import_module(module_str)
def get_tl_extra_backend_module():
    """Return the cached triton libdevice module (None until
    set_tl_extra_backend_module has run)."""
    return tl_extra_backend_module
def set_torch_backend_device_fn(vendor_name=None):
    """Resolve `torch.backends.<device_name>` and cache it in the
    module-level `torch_device_fn_device` (None for vendors without one)."""
    global device_name, torch_device_fn_device
    if not device_name:
        device_name = get_vendor_info(vendor_name).device_name
    module_str = f"torch.backends.{device_name}"
    # These devices ship no torch.backends submodule.
    if device_name in {"musa", "aipu", "npu", "txda", "ptpu", "gcu"}:
        torch_device_fn_device = None
    else:
        torch_device_fn_device = importlib.import_module(module_str)
def get_torch_backend_device_fn():
    """Return the cached `torch.backends.<device>` module; may be None."""
    return torch_device_fn_device
def gen_torch_device_object(vendor_name=None):
    """Return (and cache) the `torch.<device_name>` device module.

    The first call resolves the vendor device name; later calls return
    the cached object unchanged.
    """
    global device_name, torch_device_object
    if torch_device_object is not None:
        return torch_device_object
    device_name = device_name or get_vendor_info(vendor_name).device_name
    # Replaced the exec()-based codegen (`fn = torch.{device_name}`) with a
    # direct attribute lookup: same result, no exec() risk, and no stray
    # names leaked into this module's globals().
    torch_device_object = getattr(importlib.import_module("torch"), device_name)
    return torch_device_object
def get_vendor_module(vendor_name, query=False):
    """Import and cache the vendor backend package.

    With query=True the named module is imported and returned directly
    (for callers that want a specific instance) without touching the
    module-level cache; otherwise `_<vendor_name>` is imported once and
    cached in `vendor_module`.
    """

    def _load(module_name):
        this_dir = os.path.dirname(os.path.abspath(__file__))
        sys.path.append(this_dir)
        return importlib.import_module(module_name)

    if query:
        return _load(vendor_name)

    global vendor_module
    if vendor_module is None:
        vendor_module = _load("_" + vendor_name)
    return vendor_module
def get_vendor_info(vendor_name=None, query=False):
    """Return the vendor's `vendor_info`.

    query=True imports the named module fresh; otherwise the cached
    `vendor_module` is (lazily) populated and consulted.
    """
    if query:
        return get_vendor_module(vendor_name, query).vendor_info
    get_vendor_module(vendor_name)  # ensures the cache is populated
    return vendor_module.vendor_info
def get_vendor_infos():
    """Collect `vendor_info` for every registered vendor.

    Best-effort: vendors whose module fails to import are silently
    skipped rather than aborting the whole scan.
    """
    infos = []
    for name in vendors.get_all_vendors():
        try:
            infos.append(get_vendor_info("_" + name, query=True))
        except Exception:
            pass  # deliberately ignore unimportable vendors
    return infos
def get_current_device_extend_op(vendor_name=None):
    """Return the vendor's customized operators as [(name, fn), ...].

    Pulls functions from the vendor `ops` module first, then `fused`;
    the combined list is computed once and cached in `customized_ops`.
    """
    import_vendor_extra_lib(vendor_name)
    global customized_ops
    if customized_ops is not None:
        return customized_ops
    collected = []
    for module in (ops_module, fused_module):
        if module is not None:
            collected.extend(inspect.getmembers(module, inspect.isfunction))
    customized_ops = collected
    return customized_ops
def get_curent_device_unused_op(vendor_name=None):
    """Return the vendor's CUSTOMIZED_UNUSED_OPS as a list.

    NOTE: the 'curent' typo in the name is kept for backward
    compatibility with existing callers.
    """
    get_vendor_module(vendor_name)  # populate the vendor_module cache
    return list(vendor_module.CUSTOMIZED_UNUSED_OPS)
def get_heuristic_config(vendor_name=None):
    """Load the vendor's heuristics config, falling back to nvidia's.

    Returns the module's HEURISTICS_CONFIGS, or None when it is absent.
    """
    global heuristic_config_module
    try:
        heuristic_config_module = importlib.import_module(
            f"_{vendor_name}.heuristics_config_utils"
        )
    # BUGFIX: narrowed from a bare `except:` which also swallowed
    # KeyboardInterrupt/SystemExit; any import failure falls back.
    except Exception:
        heuristic_config_module = importlib.import_module(
            "_nvidia.heuristics_config_utils"
        )
    # getattr with a default replaces the hasattr/attribute pair.
    return getattr(heuristic_config_module, "HEURISTICS_CONFIGS", None)
def get_tune_config(vendor_name=None):
    """Ensure the vendor module is loaded, then delegate to
    backend_utils.get_tune_config for the actual tune config."""
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)
305__all__ = ["*"]