Coverage for src/flag_gems/runtime/backend/__init__.py: 78%

291 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import ast 

2import functools 

3import importlib 

4import inspect 

5import os 

6import sys 

7from pathlib import Path 

8 

9from ..common import vendors 

10from . import backend_utils 

11from .backend_utils import BackendEventBase 

12 

13 

class BackendState:
    """Process-wide container for the backend's mutable state.

    Every call to ``BackendState()`` yields the same object, and the
    attribute defaults are written only the first time it is constructed.
    """

    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            # First construction: create the object and flag it as not yet
            # populated so that __init__ fills in the defaults exactly once.
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        if not self._initialized:
            self._initialized = True
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None

    def is_available(self):
        """Always report this state object as an available op source."""
        return True

    def get_ops(self, vendor=None):
        """Return the vendor-customized ops via ``get_customized_ops``."""
        customized = get_customized_ops(vendor)
        return customized

47 

48 

# Global singleton instance: the module-level functions below read and write
# backend state (vendor module, device name, imported op modules) through it.
_state = BackendState()

51 

52 

class TritonVersionEvent(BackendEventBase):
    """Singleton event describing Triton-version-specific specializations.

    The event is available when the vendor module's directory contains a
    ``triton_<version>`` sub-directory; that directory is then imported as a
    module and stored in ``self.module``.
    """

    _instance = None
    has_version_spec = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, version=None):
        """Resolve the version and load its specialization module, if any.

        Args:
            version: Triton version string; when ``None`` the version of the
                importable ``triton`` package is used (``None`` if Triton is
                not installed).
        """
        self.has_version_spec = False
        self.version = version if version is not None else self.get_version()
        self.dir = self.get_version_spec_dir()
        if self.dir and Path(self.dir).exists():
            self.module = self.get_version_spec_module()
            self.has_version_spec = True

    def is_available(self):
        """Return True when a version-specific directory was found and loaded."""
        return self.has_version_spec

    def get_version_spec_dir(self, path=None):
        """Return the path of the ``triton_<version>`` directory, or ``None``.

        Args:
            path: Directory (or a file inside it) to search; defaults to the
                first entry of the vendor module's ``__path__``.
        """
        dir_name = f"triton_{self.version}"
        backend_path = Path(path or _state.vendor_module.__path__[0])
        backend_path = backend_path.parent if backend_path.is_file() else backend_path
        # The name always starts with "triton_", so it can never collide with
        # the excluded "ops"/"fused" directories or with "_"-prefixed ones;
        # a direct check replaces listing the whole directory.
        candidate = backend_path / dir_name
        return str(candidate) if candidate.is_dir() else None

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined in ``module`` (``[]`` if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_version_spec_module(self):
        """Import ``triton_<version>`` from the parent of ``self.dir``.

        The parent directory is put at the front of ``sys.path`` only for the
        duration of the import and is removed again even if the import raises.
        """
        module_name = f"triton_{self.version}"
        path_dir = str(os.path.dirname(self.dir))
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(module_name)
        finally:
            if path_dir in sys.path:
                sys.path.remove(path_dir)

    def get_ops(self, *args, **kwargs):
        """Return the version-specific ops as an iterable of ``(name, fn)`` pairs.

        ``get_version_ops`` currently yields nothing, so an empty list is
        returned instead of ``None``; callers iterate over the result.
        """
        return self.get_version_ops() or []

    def get_version_ops(self):
        """Hook for version-specific ops; not implemented yet (returns ``None``)."""
        pass

    def get_version(self):
        """Return ``triton.__version__``, or ``None`` when Triton is not importable."""
        try:
            import triton
        except ImportError:
            return None
        return triton.__version__

107 

108 

class BackendArchEvent(BackendEventBase):
    """Singleton event describing architecture-specific specializations.

    The event is available when the vendor module defines ``ARCH_MAP`` and the
    detected (or ``ARCH``-environment-supplied) architecture is present in it
    and has a matching directory next to the vendor module.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        """Detect the architecture once and load its module and configs.

        Args:
            backend: Stored on the instance as ``self.backend``.
        """
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            if self.current_arch_path is None:
                # The arch is listed in ARCH_MAP but has no directory; treat
                # it as unavailable instead of failing half-way through the
                # initialization of the singleton.
                self.error_msgs.append(
                    f"No specialization directory found for arch {self.arch}"
                )
                self.has_arch = False
                return
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    @staticmethod
    def _import_from_dir(directory, module_name):
        """Import ``module_name`` with ``directory`` temporarily first on ``sys.path``."""
        entry = str(directory)
        sys.path.insert(0, entry)
        try:
            return importlib.import_module(module_name)
        finally:
            # Always undo the sys.path change, including when the import raises.
            if entry in sys.path:
                sys.path.remove(entry)

    def is_available(self):
        """Return True when an architecture specialization was found."""
        return self.has_arch

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined in ``module`` (``[]`` if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` of the arch module, or ``None`` if absent.

        When the arch module has not been loaded, ``heuristics_config_utils``
        is imported from the arch directory instead.
        """
        try:
            heuristic_module = self.arch_module
        except AttributeError:
            heuristic_module = self._import_from_dir(
                self.current_arch_path, "heuristics_config_utils"
            )
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Return the tune config read from the current arch directory."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name through the vendor's ``ARCH_MAP``.

        The lookup key is the first character after the last ``_`` of the
        ``ARCH`` environment variable; when that is empty, the ``major``
        field of the device properties is used instead.

        Args:
            device: Device index passed to ``get_device_properties``.

        Returns:
            The mapped architecture name (``has_arch`` is set to True),
            ``False`` when the device reports itself unavailable, or ``None``
            when there is no ``ARCH_MAP`` or the key is not in it.
        """
        arch_map = getattr(_state.vendor_module, "ARCH_MAP", None)
        if arch_map is None:
            return None
        arch_string = os.environ.get("ARCH", "")
        # Slicing (instead of indexing) keeps an ARCH value that ends with "_"
        # from raising IndexError; it then falls back to device detection.
        arch_key = arch_string.split("_")[-1][:1]
        if not arch_key:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_key = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_key not in arch_map:
            # Report the detected key when ARCH was not set in the environment.
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string or arch_key} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_key]

    def _get_supported_archs(self, path=None):
        """Map sub-directory names to their paths, skipping ops/fused/_private ones.

        Args:
            path: Directory (or a file inside it) to scan; defaults to the
                first entry of the vendor module's ``__path__``.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the architecture directories that were found."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        return self._import_from_dir(
            os.path.dirname(self.current_arch_path), self.arch
        )

    def get_ops(self, *args, **kwargs):
        """Provide a unified interface for the upper layer"""
        return self.get_arch_ops()

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the arch module's ``ops``.

        Import failures are recorded in ``self.error_msgs`` and yield an
        empty list; each op is returned once.
        """
        # Keep the arch directory importable, but do not add it again on
        # every call.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
                return []
        return list(self.get_functions_from_module(ops_module))

219 

220 

class SpecOpRegistrar:
    """Copy specialized operators from the backend events into a registry."""

    def __init__(self, registry, vendor=None):
        """Store the target mapping and the default vendor name.

        Args:
            registry: Mutable mapping that receives ``name -> function``.
            vendor: Vendor name used by ``apply`` when none is passed.
        """
        self._globals = registry
        self.vendor = vendor

    def apply(self, vendor=None):
        """Register the ops of every available event, later events overriding.

        Args:
            vendor: Vendor name forwarded to each event's ``get_ops``;
                falls back to the one given at construction.
        """
        vendor = vendor or self.vendor
        for event in self._get_specific_events():
            if not event.is_available():
                continue
            # An event may have nothing to contribute and return None;
            # treat that the same as an empty collection.
            operators = event.get_ops(vendor) or ()
            for fn_name, fn in operators:
                self._globals[fn_name] = fn

    def _get_specific_events(self):
        """Return the event sources in registration order."""
        return (_state, BackendArchEvent(), TritonVersionEvent())

238 

239 

240def _import_module_safe(module_name, vendor_name, module_type): 

241 """Helper to import a module with proper error handling.""" 

242 try: 

243 return importlib.import_module(module_name) 

244 except ModuleNotFoundError: 

245 print( 

246 f"[Note] No specialized {module_type} operators were found for " 

247 f"the {vendor_name}, generic {module_type} operators will be used by default." 

248 ) 

249 except Exception as e: 

250 raise RuntimeError(f"Failed to import vendor extra lib: {e}") 

251 

252 

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules once.

    The results (a module or ``None``) are stored on the backend state as
    ``ops_module`` and ``fused_module``; later calls are no-ops.
    """
    if not _state.vendor_extra_lib_imported:
        # (state attribute, sub-module name, label used in the printed note)
        targets = (
            ("ops_module", "ops", "common"),
            ("fused_module", "fused", "fused"),
        )
        for state_attr, submodule, label in targets:
            loaded = _import_module_safe(
                f"_{vendor_name}.{submodule}", vendor_name, label
            )
            setattr(_state, state_attr, loaded)
        _state.vendor_extra_lib_imported = True

263 

264 

def get_codegen_result(code, result_key):
    """Execute ``code`` in this module's globals and return ``result_key``.

    Args:
        code: Python source text to execute.
        result_key: Name the executed code is expected to bind.

    Returns:
        The value bound to ``result_key`` in the module globals afterwards.

    Raises:
        KeyError: If ``result_key`` is not bound after execution.
        Exception: Anything raised by parsing or executing ``code``.
    """
    # NOTE: the code runs with full module privileges and its bindings stay
    # in the module namespace; only pass trusted, internally built source.
    module_namespace = globals()
    compiled = compile(ast.parse(code), filename="<ast>", mode="exec")
    exec(compiled, module_namespace)
    return module_namespace[result_key]

273 

274 

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code and cache it.

    Both arguments are interpolated into source text that first imports
    ``torch``; results are memoized per ``(tensor, attr_name)`` pair.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    result = get_codegen_result(code, "res")
    return result

283 

284 

def set_tl_extra_backend_module(vendor_name=None):
    """Import the vendor's Triton ``libdevice`` module and store it on the state.

    The module path uses the vendor's ``triton_extra_name`` when it is set,
    otherwise the device name.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra_name = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )

291 

292 

def get_tl_extra_backend_module():
    """Return the Triton extra backend module currently stored on the state."""
    backend_state = _state
    return backend_state.tl_extra_backend_module

295 

296 

def set_torch_backend_device_fn(vendor_name=None):
    """Store ``torch.backends.<device_name>`` on the state.

    For the device names listed below no module is imported and ``None`` is
    stored instead.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    devices_without_backend_module = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    if _state.device_name in devices_without_backend_module:
        backend_module = None
    else:
        backend_module = importlib.import_module(
            f"torch.backends.{_state.device_name}"
        )
    _state.torch_device_fn_device = backend_module

304 

305 

def get_torch_backend_device_fn():
    """Return the torch backend module currently stored on the state."""
    backend_state = _state
    return backend_state.torch_device_fn_device

308 

309 

def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use.

    For the "spacemit" vendor the resolved object is additionally patched
    with a device guard, a device wrapper and a constant ``current_device``.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_obj = get_codegen_result(code, "fn")
    _state.torch_device_object = device_obj

    # SPACEMIT CPU backend needs special device guard handling
    if vendor_name == "spacemit":
        spacemit_backend = importlib.import_module(
            "flag_gems.runtime.backend._spacemit"
        )
        device_obj._DeviceGuard = spacemit_backend._DeviceGuard
        device_obj.device = spacemit_backend._DeviceWrapper
        # Override current_device to return integer 0 for kernel cache indexing
        device_obj.current_device = lambda: 0

    return _state.torch_device_object

337 

338 

def get_vendor_module(vendor_name, query=False):
    """Import a vendor module from this file's directory.

    Args:
        vendor_name: With ``query=True`` the exact module name to import;
            otherwise the vendor name, to which ``_`` is prefixed.
        query: When True, import and return the requested module without
            touching the cached backend state.

    Returns:
        The imported module. Without ``query`` the first imported vendor
        module is cached on the backend state and returned on later calls.
    """

    def get_module(module_name):
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        # Make sibling vendor packages importable by bare name, but add the
        # directory only once so repeated lookups do not grow sys.path.
        if backend_dir not in sys.path:
            sys.path.append(backend_dir)
        return importlib.import_module(module_name)

    # The purpose of a query is to provide the user with the instance that
    # he wants to import
    if query:
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module

354 

355 

def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    With ``query=True`` the module named ``vendor_name`` is looked up
    directly; otherwise the vendor module cached on the backend state is used
    (loading it first if necessary).
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    queried_module = get_vendor_module(vendor_name, query)
    return queried_module.vendor_info

361 

362 

def get_vendor_infos():
    """Collect ``vendor_info`` for every known vendor that can be loaded.

    Vendors whose module fails to load for any reason are skipped.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            # Best effort: a vendor that cannot be imported is left out.
            continue
        collected.append(info)
    return collected

372 

373 

def get_customized_ops(vendor_name=None):
    """Return the ``(name, function)`` pairs of the vendor's ops and fused ops.

    The list is built on first use from the imported ``ops`` and ``fused``
    modules (in that order) and cached on the backend state afterwards.
    """
    import_vendor_extra_lib(vendor_name)
    if _state.customized_ops is None:
        collected = _state.customized_ops = []
        for module in (_state.ops_module, _state.fused_module):
            if module is not None:
                collected.extend(inspect.getmembers(module, inspect.isfunction))
    return _state.customized_ops

386 

387 

def get_ops(vendor_name=None):
    """Return the vendor-customized ops; thin alias of ``get_customized_ops``."""
    customized = get_customized_ops(vendor_name)
    return customized

391 

392 

def get_unused_ops(vendor_name=None):
    """Return a list copy of the vendor module's ``CUSTOMIZED_UNUSED_OPS``.

    The vendor module is loaded (and cached on the backend state) first if
    that has not happened yet.
    """
    # The former ``global vendor_module`` declaration was dropped: no such
    # module-level name exists any more, the module lives on ``_state``.
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)

397 

398 

def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` for the vendor, falling back to ``_nvidia``.

    The module that was imported is stored on the backend state; ``None`` is
    returned when it does not define ``HEURISTICS_CONFIGS``.
    """
    config_name = "heuristics_config_utils"
    try:
        config_module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        # Any failure to load the vendor's own config selects the nvidia one.
        config_module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = config_module
    return getattr(config_module, "HEURISTICS_CONFIGS", None)

408 

409 

def get_tune_config(vendor_name=None):
    """Return the tune config for ``vendor_name`` from ``backend_utils``.

    The vendor module is loaded (and cached on the backend state) first if
    that has not happened yet.
    """
    # The former ``global vendor_module`` declaration was dropped: no such
    # module-level name exists any more, the module lives on ``_state``.
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)

414 

415 

def get_expand_config(op_name=None, file_path=None):
    """Forward to ``backend_utils.get_expand_config`` with the same arguments."""
    forwarded = {"op_name": op_name, "file_path": file_path}
    return backend_utils.get_expand_config(**forwarded)

418 

419 

def get_backend_state() -> BackendState:
    """Return the module-level ``BackendState`` singleton."""
    backend_state = _state
    return backend_state

423 

424 

# Explicit public API. The previous value ``["*"]`` made
# ``from <this package> import *`` fail with AttributeError, because ``*`` is
# looked up as an attribute name rather than treated as a wildcard.
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "SpecOpRegistrar",
    "TritonVersionEvent",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_ops",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]