Coverage for src/flag_gems/runtime/backend/__init__.py: 83%

218 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import ast 

2import functools 

3import importlib 

4import inspect 

5import os 

6import sys 

7from pathlib import Path 

8 

9from ..common import vendors 

10from . import backend_utils 

11 

12 

class BackendState:
    """Process-wide container for the backend's lazily resolved state.

    Every ``BackendState()`` call yields the same object. The attributes are
    reset to their defaults only the first time that object is initialised;
    later constructions leave the stored values untouched.
    """

    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            # First construction: create the shared object and mark it as
            # not yet initialised so __init__ fills in the defaults once.
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        if not self._initialized:
            self._initialized = True
            # Vendor / device identity.
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            # Modules resolved on demand.
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            # Bookkeeping flags and caches.
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None

39 

40 

# Global singleton instance, created at import time and read/written by the
# module-level helper functions in this file.
_state = BackendState()

43 

44 

class BackendArchEvent:
    """Singleton that resolves architecture-specific resources for the vendor.

    On first construction it asks ``get_arch`` for the current architecture.
    When the vendor module's ``ARCH_MAP`` recognises it, the instance records
    the architecture directory, imports the architecture module and loads its
    autotune and heuristics configuration. Later constructions return the same
    instance without repeating that work.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        # One shared instance; constructor arguments are consumed by __init__.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    @staticmethod
    def _import_with_path(directory, module_name):
        """Import ``module_name`` with ``directory`` temporarily first on ``sys.path``.

        The ``sys.path`` entry is removed again even when the import raises, so
        a failed import does not leave the interpreter's search path modified.
        """
        directory = str(directory)
        sys.path.insert(0, directory)
        try:
            return importlib.import_module(module_name)
        finally:
            try:
                sys.path.remove(directory)
            except ValueError:
                # The imported code already removed the entry itself.
                pass

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs found in ``module`` (``[]`` if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` of the architecture module, or ``None``.

        If ``arch_module`` has not been set on the instance, the module
        ``heuristics_config_utils`` is imported from ``current_arch_path``
        instead.
        """
        try:
            heuristic_module = self.arch_module
        except AttributeError:
            heuristic_module = self._import_with_path(
                self.current_arch_path, "heuristics_config_utils"
            )
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load the tune configuration located at ``current_arch_path``."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name through the vendor's ``ARCH_MAP``.

        The lookup key is the first character after the last ``_`` of the
        ``ARCH`` environment variable; when that is empty, the ``major`` field
        of the device properties of ``device`` is used instead.

        Returns:
            The mapped architecture name (and sets ``has_arch`` to ``True``),
            ``False`` when the device object reports it is unavailable, or
            ``None`` when the vendor has no ``ARCH_MAP`` or the key is not in
            the map.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # Slice instead of indexing so an empty value or a trailing "_" yields
        # "" rather than raising IndexError.
        arch_string_num = arch_string.split("_")[-1][:1]
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                # Device query failed; fall through to the "unsupported" report.
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map sub-directory names to their paths under the vendor directory.

        ``path`` defaults to the first entry of the vendor module's
        ``__path__``. Directories named ``ops`` or ``fused`` and those starting
        with ``_`` are skipped.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the supported architectures as a list."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        path_dir = os.path.dirname(self.current_arch_path)
        return self._import_with_path(path_dir, self.arch)

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs of the architecture's ``ops`` module.

        The ``ops`` attribute of the already imported architecture module is
        preferred; otherwise ``<arch>.ops`` is imported. An import failure is
        appended to ``error_msgs`` and an empty list is returned.
        """
        arch_specialized_ops = []
        arch_path = self.current_arch_path
        # Keep the architecture directory importable, but add it only once.
        if arch_path is not None and arch_path not in sys.path:
            sys.path.append(arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
                return arch_specialized_ops
        # Collect on every successful resolution, not only on a retry path.
        arch_specialized_ops.extend(self.get_functions_from_module(ops_module))
        return arch_specialized_ops

145 

146 

147def _import_module_safe(module_name, vendor_name, module_type): 

148 """Helper to import a module with proper error handling.""" 

149 try: 

150 return importlib.import_module(module_name) 

151 except ModuleNotFoundError: 

152 print( 

153 f"[Note] No specialized {module_type} operators were found for " 

154 f"the {vendor_name}, generic {module_type} operators will be used by default." 

155 ) 

156 except Exception as e: 

157 raise RuntimeError(f"Failed to import vendor extra lib: {e}") 

158 

159 

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules once.

    The results of importing ``_<vendor_name>.ops`` and
    ``_<vendor_name>.fused`` are stored on the shared state. Calls made after
    the first completed one return immediately.
    """
    if not _state.vendor_extra_lib_imported:
        package = f"_{vendor_name}"
        _state.ops_module = _import_module_safe(
            package + ".ops", vendor_name, "common"
        )
        _state.fused_module = _import_module_safe(
            package + ".fused", vendor_name, "fused"
        )
        _state.vendor_extra_lib_imported = True

170 

171 

def get_codegen_result(code, result_key):
    """Execute generated source and return one name it defines.

    Args:
        code: Python source text. It is run with ``exec`` and therefore must
            come from a trusted source.
        result_key: Name to read back after execution.

    Returns:
        The value bound to ``result_key`` once ``code`` has run.

    Raises:
        KeyError: If ``result_key`` is not defined after execution.
        Exception: Whatever parsing or executing ``code`` raises.
    """
    parsed_ast = ast.parse(code)
    compiled_code = compile(parsed_ast, filename="<ast>", mode="exec")
    # Run against a copy of the module namespace: the snippet can still read
    # module globals, but names it binds no longer leak into the module, so a
    # value left behind by an earlier call cannot be returned as a stale result.
    namespace = dict(globals())
    exec(compiled_code, namespace)
    return namespace[result_key]

180 

181 

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code.

    Both arguments are interpolated into source text that is executed after
    ``import torch``. Results are memoised per argument pair by the
    ``lru_cache`` decorator. The shared ``device_name`` is filled in from the
    vendor info when it is still unset.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")

190 

191 

def set_tl_extra_backend_module(vendor_name=None):
    """Import the Triton ``libdevice`` module for the vendor and store it.

    The module name is ``triton.language.extra.<name>.libdevice``, where
    ``<name>`` is the vendor info's ``triton_extra_name`` when that is truthy
    and the shared ``device_name`` otherwise.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra}.libdevice"
    )

198 

199 

def get_tl_extra_backend_module():
    """Return the module stored by ``set_tl_extra_backend_module`` (or ``None``)."""
    return _state.tl_extra_backend_module

202 

203 

def set_torch_backend_device_fn(vendor_name=None):
    """Store the ``torch.backends.<device_name>`` module on the shared state.

    For the device names listed below no import is attempted and ``None`` is
    stored instead.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    name = _state.device_name
    skipped_devices = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    _state.torch_device_fn_device = (
        None
        if name in skipped_devices
        else importlib.import_module(f"torch.backends.{name}")
    )

211 

212 

def get_torch_backend_device_fn():
    """Return the value stored by ``set_torch_backend_device_fn`` (may be ``None``)."""
    return _state.torch_device_fn_device

215 

216 

def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use.

    The attribute is looked up by executing generated source; the result is
    kept on the shared state and returned directly by later calls.
    """
    if _state.torch_device_object is None:
        if not _state.device_name:
            _state.device_name = get_vendor_info(vendor_name).device_name
        code = f"""
import torch
fn = torch.{_state.device_name}
"""
        _state.torch_device_object = get_codegen_result(code, "fn")
    return _state.torch_device_object

227 

228 

def get_vendor_module(vendor_name, query=False):
    """Import a vendor module from this package's directory.

    Args:
        vendor_name: With ``query=True`` the exact module name to import;
            otherwise the vendor name, to which ``_`` is prefixed.
        query: When true, import and return the requested module without
            touching the shared state.

    Returns:
        The imported module. Without ``query`` the module is imported only
        the first time and the one stored on the shared state is returned.
    """

    def get_module(vendor_name):
        current_dir_path = os.path.dirname(os.path.abspath(__file__))
        # Register the directory once; appending on every call would grow
        # sys.path with duplicate entries.
        if current_dir_path not in sys.path:
            sys.path.append(current_dir_path)
        return importlib.import_module(vendor_name)

    # The purpose of a query is to provide the user with the instance that he
    # wants to import.
    if query:
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module

244 

245 

def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    With ``query=True`` the module is looked up by name for this call only;
    otherwise the module held on the shared state is used, after
    ``get_vendor_module`` has had the chance to populate it.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    queried_module = get_vendor_module(vendor_name, query)
    return queried_module.vendor_info

251 

252 

def get_vendor_infos():
    """Collect ``vendor_info`` for every vendor whose module can be queried.

    Vendors for which the lookup raises are skipped.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            # Best effort: skip vendors whose lookup fails.
            continue
        collected.append(info)
    return collected

262 

263 

def get_customized_ops(vendor_name=None):
    """Return the ``(name, function)`` pairs of the vendor's ops and fused modules.

    The list is built on the first call and stored on the shared state; later
    calls return that stored list.
    """
    import_vendor_extra_lib(vendor_name)
    if _state.customized_ops is None:
        # Publish the list before filling it, then extend it in place.
        _state.customized_ops = gathered = []
        for module in (_state.ops_module, _state.fused_module):
            if module is not None:
                gathered.extend(inspect.getmembers(module, inspect.isfunction))
    return _state.customized_ops

276 

277 

def get_unused_ops(vendor_name=None):
    """Return a new list copied from the vendor module's ``CUSTOMIZED_UNUSED_OPS``.

    ``get_vendor_module`` is called first so the shared vendor module is
    populated when it has not been loaded yet.
    """
    # The vendor module is kept on ``_state``; no module-level global is
    # involved, so no ``global`` declaration is needed here.
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)

282 

283 

def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from the first importable config module.

    ``_<vendor_name>.heuristics_config_utils`` is tried first, then
    ``_nvidia.heuristics_config_utils``. The module that imports successfully
    is stored on the shared state. ``None`` is returned when neither imports
    or when the module lacks the attribute.
    """
    candidates = [
        f"_{backend}.heuristics_config_utils" for backend in (vendor_name, "nvidia")
    ]
    for mod_name in candidates:
        try:
            module = importlib.import_module(mod_name)
        except Exception:
            continue
        _state.heuristic_config_module = module
        return getattr(module, "HEURISTICS_CONFIGS", None)
    return None

294 

295 

def get_tune_config(vendor_name=None):
    """Return ``backend_utils.get_tune_config(vendor_name)``.

    ``get_vendor_module`` is called first so the shared vendor module is
    populated when it has not been loaded yet.
    """
    # The vendor module is kept on ``_state``; no module-level global is
    # involved, so no ``global`` declaration is needed here.
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)

300 

301 

def get_expand_config(op_name=None, file_path=None):
    """Forward both arguments by keyword to ``backend_utils.get_expand_config``."""
    return backend_utils.get_expand_config(file_path=file_path, op_name=op_name)

304 

305 

def get_backend_state() -> BackendState:
    """Return the module-level ``BackendState`` object shared by this module."""
    return _state

309 

310 

# ``__all__`` must list real attribute names: an entry of "*" makes
# ``from <this module> import *`` fail because no attribute named "*" exists.
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]