Coverage for src/flag_gems/runtime/backend/__init__.py: 78%
291 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import ast
2import functools
3import importlib
4import inspect
5import os
6import sys
7from pathlib import Path
9from ..common import vendors
10from . import backend_utils
11from .backend_utils import BackendEventBase
class BackendState:
    """Process-wide container for lazily resolved backend objects.

    The class is a singleton: every ``BackendState()`` call hands back the
    same object, and the attribute defaults below are assigned only on the
    very first construction (later constructions leave current values alone).
    """

    _instance = None

    def __new__(cls):
        if cls._instance is not None:
            return cls._instance
        fresh = super().__new__(cls)
        # Tells __init__ that the defaults have not been assigned yet.
        fresh._initialized = False
        cls._instance = fresh
        return fresh

    def __init__(self):
        # __init__ runs on every BackendState() call; only the first call
        # may assign the defaults.
        if self._initialized:
            return
        self._initialized = True

        # Vendor package (``_<vendor>``) and its device name.
        self.vendor_module = None
        self.device_name = None
        # ``torch.<device>`` namespace and ``torch.backends.<device>`` module.
        self.torch_device_object = None
        self.torch_device_fn_device = None
        # ``triton.language.extra.<name>.libdevice`` module.
        self.tl_extra_backend_module = None
        # Vendor ``ops`` / ``fused`` packages and the heuristics module.
        self.ops_module = None
        self.fused_module = None
        self.heuristic_config_module = None
        self.vendor_extra_lib_imported = False
        self.device_fn_cache = {}
        # Cached list of (name, function) pairs from get_customized_ops().
        self.customized_ops = None

    def is_available(self):
        """The vendor-level state always counts as available."""
        return True

    def get_ops(self, vendor=None):
        """Provide a unified interface for the upper layer.

        Delegates to the module-level ``get_customized_ops``.
        """
        return get_customized_ops(vendor)


# Global singleton instance
_state = BackendState()
class TritonVersionEvent(BackendEventBase):
    """Event exposing Triton-version-specific operator overrides.

    It looks for a ``triton_<version>`` sub-directory of the vendor package
    directory. If one exists, it is imported and the event reports itself as
    available. ``__new__`` always returns the same instance, but ``__init__``
    re-runs (and redoes the lookup) on every construction.
    """

    _instance = None
    has_version_spec = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, version=None):
        self.has_version_spec = False
        self.version = version if version is not None else self.get_version()
        self.dir = self.get_version_spec_dir()
        if self.dir and Path(self.dir).exists():
            self.module = self.get_version_spec_module()
            self.has_version_spec = True

    def is_available(self):
        return self.has_version_spec

    def get_version_spec_dir(self, path=None):
        """Return the path of the ``triton_<version>`` sub-directory, or None.

        ``path`` defaults to the loaded vendor package directory. Directories
        named ``ops``/``fused`` or starting with ``_`` are never considered.
        Returns None when no ``path`` is given and no vendor module is loaded.
        """
        if not path and _state.vendor_module is None:
            # No vendor package loaded yet, so there is nowhere to look.
            return None
        dir_name = f"triton_{self.version}"
        backend_path = Path(path or _state.vendor_module.__path__[0])
        backend_path = backend_path.parent if backend_path.is_file() else backend_path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in backend_path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }.get(dir_name, None)

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined on ``module`` ([] for None)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_version_spec_module(self):
        """Import the ``triton_<version>`` package found by get_version_spec_dir.

        The parent directory is put on ``sys.path`` only for the duration of
        the import and is removed again even if the import fails.
        """
        # NOTE(review): Triton versions usually contain dots (e.g. "3.1.0"),
        # which import_module interprets as a package path ("triton_3" ->
        # "1" -> "0"). Confirm how version directories are named.
        module_name = f"triton_{self.version}"
        path_dir = str(os.path.dirname(self.dir))
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(module_name)
        finally:
            sys.path.remove(path_dir)

    def get_ops(self, *args, **kwargs):
        """Provide a unified interface for the upper layer."""
        return self.get_version_ops()

    def get_version_ops(self):
        """Return ``(name, function)`` pairs specific to this Triton version.

        No version-specific operators are collected yet. An empty list is
        returned so callers can always iterate the result.
        """
        # TODO: collect operators from self.module once the layout of the
        # version packages is defined.
        return []

    def get_version(self):
        """Return ``triton.__version__``, or None if Triton is not installed."""
        try:
            import triton
        except ImportError:
            return None
        return triton.__version__
class BackendArchEvent(BackendEventBase):
    """Singleton event exposing architecture-specific operator overrides.

    The architecture key comes from the ``ARCH`` environment variable or, if
    that yields nothing, from the ``major`` field of the device properties.
    The key is translated through the vendor module's ``ARCH_MAP`` into the
    name of a sub-directory of the vendor package (e.g. ``_nvidia/hopper``).
    Initialisation runs only once (guarded by the class-level ``_initialized``).
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if not self.has_arch:
            return
        self.supported_archs = self._get_supported_archs()
        # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
        self.current_arch_path = self.supported_archs.get(self.arch)
        if self.current_arch_path is None:
            # ARCH_MAP names an architecture without a matching directory:
            # use the generic operators instead of failing in get_arch_module.
            self.has_arch = False
            return
        self.arch_module = self.get_arch_module()
        self.autotune_configs = self.get_autotune_configs()
        self.heuristics_configs = self.get_heuristics_configs()

    def is_available(self):
        return self.has_arch

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined on ``module`` ([] for None)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` from the arch module, or None."""
        return getattr(self.arch_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load tuning configs from the arch directory via backend_utils."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name for ``device`` through ``ARCH_MAP``.

        Returns the mapped name (and sets ``has_arch``) on success. Returns
        None when the vendor module has no ``ARCH_MAP`` or the architecture is
        not in it, and False when the device reports it is unavailable.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # e.g. ARCH="sm_90" -> "9". Slicing (rather than indexing) makes a
        # malformed value such as "sm_" fall back to the device query instead
        # of raising IndexError.
        # NOTE(review): only the first digit is kept, so "sm_100" gives "1"
        # while a device with major 10 gives "10"; confirm against ARCH_MAP.
        arch_string_num = arch_string.split("_")[-1][:1]
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_string_num not in arch_map:
            # Report the device-derived key when ARCH was not set.
            print(
                f"[INFO] : FlagGems Unsupported GPU arch "
                f"{arch_string or arch_string_num} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map sub-directory name -> path for candidate arch directories."""
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>

        The parent directory is on ``sys.path`` only during the import and is
        removed again even if the import fails.
        """
        path_dir = str(os.path.dirname(self.current_arch_path))
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(self.arch)
        finally:
            sys.path.remove(path_dir)

    def get_ops(self, *args, **kwargs):
        """Provide a unified interface for the upper layer"""
        return self.get_arch_ops()

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the ``<arch>.ops`` module.

        Uses ``arch_module.ops`` when already loaded, otherwise imports
        ``<arch>.ops``. Import failures are recorded in ``error_msgs`` and
        produce an empty list.
        """
        # Expose the arch directory on sys.path, but only once so repeated
        # calls do not keep growing sys.path.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
                return []
        return self.get_functions_from_module(ops_module)
class SpecOpRegistrar:
    """Write vendor / arch / Triton-version specific operators into a registry.

    ``registry`` is any mapping supporting item assignment. For every
    available event, its ``(name, function)`` pairs are stored in it; later
    events override earlier ones with the same name.
    """

    def __init__(self, registry, vendor=None):
        self._globals = registry
        self.vendor = vendor

    def apply(self, vendor=None):
        """Register the operators of every available event.

        ``vendor`` defaults to the vendor given at construction.
        """
        vendor = vendor or self.vendor
        for event in self._get_specific_events():
            if not event.is_available():
                continue
            # An event may return None (no operators); treat it as empty
            # instead of failing on iteration.
            operators = event.get_ops(vendor) or ()
            for fn_name, fn in operators:
                self._globals[fn_name] = fn

    def _get_specific_events(self):
        return (_state, BackendArchEvent(), TritonVersionEvent())
240def _import_module_safe(module_name, vendor_name, module_type):
241 """Helper to import a module with proper error handling."""
242 try:
243 return importlib.import_module(module_name)
244 except ModuleNotFoundError:
245 print(
246 f"[Note] No specialized {module_type} operators were found for "
247 f"the {vendor_name}, generic {module_type} operators will be used by default."
248 )
249 except Exception as e:
250 raise RuntimeError(f"Failed to import vendor extra lib: {e}")
def import_vendor_extra_lib(vendor_name=None):
    """Load the vendor's optional ``ops`` and ``fused`` packages, once.

    The modules (or None when absent) are stored on the backend state;
    calls after the first one do nothing.
    """
    if not _state.vendor_extra_lib_imported:
        package = f"_{vendor_name}"
        _state.ops_module = _import_module_safe(
            f"{package}.ops", vendor_name, "common"
        )
        _state.fused_module = _import_module_safe(
            f"{package}.fused", vendor_name, "fused"
        )
        _state.vendor_extra_lib_imported = True
def get_codegen_result(code, result_key):
    """Execute generated Python ``code`` and return the value bound to ``result_key``.

    The code runs in this module's global namespace, so every name it binds
    (including ``result_key``) remains as a module global afterwards.
    Exceptions raised by the code propagate unchanged.

    Security: ``code`` is executed verbatim; only pass trusted, internally
    generated source.
    """
    parsed_ast = ast.parse(code)
    compiled_code = compile(parsed_ast, filename="<ast>", mode="exec")
    exec(compiled_code, globals())
    return globals()[result_key]
@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code and return it.

    ``tensor`` and ``attr_name`` are source-text fragments spliced into the
    generated code, which starts with ``import torch``. Results are memoised
    per argument pair (LRU, 32 entries). The first call also resolves
    ``device_name`` on the backend state when it is still unset.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")
def set_tl_extra_backend_module(vendor_name=None):
    """Import and store ``triton.language.extra.<name>.libdevice``.

    ``<name>`` is the vendor's ``triton_extra_name`` when set, otherwise the
    device name.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra_name = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )
def get_tl_extra_backend_module():
    """Return the libdevice module stored by set_tl_extra_backend_module (None until set)."""
    return getattr(_state, "tl_extra_backend_module")
def set_torch_backend_device_fn(vendor_name=None):
    """Store ``torch.backends.<device>`` on the backend state.

    For the devices listed below nothing is imported and None is stored
    (presumably they have no ``torch.backends`` module; verify per vendor).
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    device = _state.device_name
    without_backend_module = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    _state.torch_device_fn_device = (
        None
        if device in without_backend_module
        else importlib.import_module(f"torch.backends.{device}")
    )
def get_torch_backend_device_fn():
    """Return the module stored by set_torch_backend_device_fn (None until set or for unsupported devices)."""
    return getattr(_state, "torch_device_fn_device")
def gen_torch_device_object(vendor_name=None):
    """Return, and cache on the backend state, the ``torch.<device_name>`` object.

    For the ``spacemit`` vendor the object is patched with that backend's
    ``_DeviceGuard``/``_DeviceWrapper`` and a ``current_device`` returning 0.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached

    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_obj = get_codegen_result(code, "fn")
    _state.torch_device_object = device_obj

    # SPACEMIT CPU backend needs special device guard handling
    if vendor_name == "spacemit":
        spacemit = importlib.import_module("flag_gems.runtime.backend._spacemit")
        device_obj._DeviceGuard = spacemit._DeviceGuard
        device_obj.device = spacemit._DeviceWrapper
        # Override current_device to return integer 0 for kernel cache indexing
        device_obj.current_device = lambda: 0

    return device_obj
def get_vendor_module(vendor_name, query=False):
    """Import a vendor package located in this ``backend`` directory.

    Args:
        vendor_name: package name. With ``query=True`` it is imported as
            given (callers pass e.g. ``"_nvidia"``); otherwise ``"_"`` is
            prepended.
        query: when True, import and return the package without touching the
            backend state, so any vendor can be inspected. When False, the
            first call imports and caches the package on the backend state,
            and later calls return the cached module regardless of
            ``vendor_name``.
    """

    def get_module(vendor_name):
        current_dir_path = os.path.dirname(os.path.abspath(__file__))
        # Make the sibling vendor packages importable by top-level name.
        # Add the directory once instead of on every call.
        if current_dir_path not in sys.path:
            sys.path.append(current_dir_path)
        return importlib.import_module(vendor_name)

    if query:
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module
def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` object of a vendor package.

    With ``query=True`` the package named ``vendor_name`` is imported on
    demand. Otherwise the backend state's vendor module is used, loaded on
    first use.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    return get_vendor_module(vendor_name, query).vendor_info
def get_vendor_infos():
    """Collect ``vendor_info`` from every known vendor, skipping any that fail to load."""
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            # Vendor package missing or broken: leave it out of the listing.
            continue
        collected.append(info)
    return collected
def get_customized_ops(vendor_name=None):
    """Return ``(name, function)`` pairs from the vendor's ops and fused packages.

    The list is built once and cached on the backend state.
    """
    import_vendor_extra_lib(vendor_name)
    if _state.customized_ops is None:
        _state.customized_ops = []
        for module in (_state.ops_module, _state.fused_module):
            if module is not None:
                _state.customized_ops += inspect.getmembers(
                    module, inspect.isfunction
                )
    return _state.customized_ops
def get_ops(vendor_name=None):
    """Provide a unified interface for the upper layer.

    Alias of get_customized_ops.
    """
    ops = get_customized_ops(vendor_name)
    return ops
def get_unused_ops(vendor_name=None):
    """Return a list copy of the vendor module's ``CUSTOMIZED_UNUSED_OPS``.

    The vendor module is loaded (and cached) first if necessary.
    """
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)
def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from ``_<vendor>.heuristics_config_utils``.

    Falls back to ``_nvidia.heuristics_config_utils`` when importing the
    vendor's module raises. The module used is stored on the backend state.
    Returns None if it lacks the attribute.
    """
    config_name = "heuristics_config_utils"
    try:
        module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = module
    return getattr(module, "HEURISTICS_CONFIGS", None)
def get_tune_config(vendor_name=None):
    """Return the vendor's tuning config from ``backend_utils.get_tune_config``.

    The vendor module is loaded (and cached) first, presumably because
    backend_utils relies on it being importable. Verify against backend_utils.
    """
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)
def get_expand_config(op_name=None, file_path=None):
    """Thin wrapper forwarding both arguments to ``backend_utils.get_expand_config``."""
    return backend_utils.get_expand_config(file_path=file_path, op_name=op_name)
def get_backend_state() -> BackendState:
    """Return the module-level BackendState singleton shared by all helpers here."""
    state = _state
    return state
# Explicit public API. ``from flag_gems.runtime.backend import *`` looks up
# every name listed here as a module attribute, so the list must contain real
# names; a literal "*" entry would make the star-import raise AttributeError.
__all__ = [
    "BackendState",
    "TritonVersionEvent",
    "BackendArchEvent",
    "SpecOpRegistrar",
    "import_vendor_extra_lib",
    "get_codegen_result",
    "gen_torch_tensor_attr_res",
    "set_tl_extra_backend_module",
    "get_tl_extra_backend_module",
    "set_torch_backend_device_fn",
    "get_torch_backend_device_fn",
    "gen_torch_device_object",
    "get_vendor_module",
    "get_vendor_info",
    "get_vendor_infos",
    "get_customized_ops",
    "get_ops",
    "get_unused_ops",
    "get_heuristic_config",
    "get_tune_config",
    "get_expand_config",
    "get_backend_state",
]