Coverage for src/flag_gems/runtime/backend/__init__.py: 82%

217 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import ast 

2import functools 

3import importlib 

4import inspect 

5import os 

6import sys 

7from pathlib import Path 

8 

9from ..common import vendors 

10from . import backend_utils 

11 

# Lazily populated module-level state shared by the helpers below.
vendor_module = None  # imported `_<vendor>` backend package (set by get_vendor_module)
device_name = None  # vendor's torch device name string, e.g. "cuda"
torch_device_object = None  # the `torch.<device_name>` namespace object (gen_torch_device_object)
torch_device_fn_device = None  # `torch.backends.<device_name>` module, or None (set_torch_backend_device_fn)
tl_extra_backend_module = None  # `triton.language.extra.<name>.libdevice` module (set_tl_extra_backend_module)
ops_module = None  # vendor-specific common ops module, if the vendor provides one
fused_module = None  # vendor-specific fused ops module, if the vendor provides one
heuristic_config_module = None  # vendor's heuristics_config_utils module (get_heuristic_config)
vendor_extra_lib_imported = False  # one-shot guard so import_vendor_extra_lib runs at most once
device_fn_cache = {}  # NOTE(review): not referenced anywhere in this file — verify callers before removing
customized_ops = None  # cached result of get_current_device_extend_op

23 

24 

class BackendArchEvent:
    """Singleton describing the current device architecture for a backend.

    On first construction it resolves the device arch (from the ``ARCH``
    environment variable or the torch device's compute-capability major
    number, mapped through the vendor module's ``ARCH_MAP``) and, when a
    supported arch is found, loads the arch-specific module plus its
    autotune and heuristics configurations.
    """

    # True once get_arch resolved an arch that the vendor supports.
    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every construction returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        # __init__ runs on every BackendArchEvent() call; initialize only once.
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like
            # FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    def get_functions_from_module(self, module):
        """Return [(name, function), ...] for `module`, or [] when module is falsy."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return the arch module's HEURISTICS_CONFIGS, or None if absent."""
        try:
            heuristic_module = self.arch_module
        except Exception:
            # Fallback: import heuristics_config_utils straight from the arch
            # directory. sys.path is restored even if the import fails
            # (previously the inserted entry leaked on failure).
            sys.path.insert(0, str(self.current_arch_path))
            try:
                heuristic_module = importlib.import_module("heuristics_config_utils")
            finally:
                sys.path.remove(str(self.current_arch_path))
        if hasattr(heuristic_module, "HEURISTICS_CONFIGS"):
            return heuristic_module.HEURISTICS_CONFIGS
        return None

    def get_autotune_configs(self):
        """Load the autotune configs stored under the current arch path."""
        return backend_utils.get_tune_config(file_path=self.current_arch_path)

    def get_arch(self, device=0):
        """Resolve the arch name for `device` and set `self.has_arch`.

        Resolution order: the ``ARCH`` env var (e.g. "sm_9x" -> "9"), then
        the torch device's compute-capability major number. Returns the
        mapped arch name on success, False when the device is unavailable,
        or None when no mapping applies (quirky mixed returns preserved —
        callers gate on `has_arch`, not on the return value).
        """
        if not hasattr(vendor_module, "ARCH_MAP"):
            return
        arch_map = vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # e.g. ARCH=sm_90 -> first char of the last "_" segment -> "9"
        arch_string_num = arch_string.split("_")[-1][0] if arch_string else arch_string
        if not arch_string_num:
            try:
                if not torch_device_object.is_available():
                    return False
                props = torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                # Device probing failed; fall through with the empty string,
                # which cannot be in arch_map.
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
        else:
            self.has_arch = True
            return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map arch-directory name -> absolute path for each arch dir under `path`.

        Skips the "ops"/"fused" directories and anything starting with "_".
        Defaults to the vendor module's package directory.
        """
        path = path or vendor_module.__path__[0]
        excluded = ("ops", "fused")
        path = Path(path)
        path = path.parent if path.is_file() else path
        archs = {}
        for p in path.iterdir():
            # p.name instead of str(p).split("/")[-1] so this also works on
            # platforms whose path separator is not "/".
            name = p.name
            if p.is_dir() and name not in excluded and not name.startswith("_"):
                archs[name] = str(p)
        return archs

    def get_supported_archs(self):
        """Return the list of supported arch names."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>, restoring sys.path even when the import fails."""
        path_dir = os.path.dirname(self.current_arch_path)
        sys.path.insert(0, str(path_dir))
        try:
            return importlib.import_module(self.arch)
        finally:
            sys.path.remove(str(path_dir))

    def get_arch_ops(self):
        """Collect [(name, fn), ...] specialized for the current arch.

        Prefers `self.arch_module.ops`; falls back to importing
        "<arch>.ops" from the arch path. Import failures are recorded in
        `self.error_msgs` rather than raised. (An unguarded duplicate of
        the fallback import used to run before the try block, able to
        raise uncaught and leaking a sys.path entry — removed.)
        """
        modules = []
        try:
            modules.append(self.arch_module.ops)
        except Exception:
            sys.path.append(self.current_arch_path)
            try:
                modules.append(importlib.import_module(f"{self.arch}.ops"))
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
            finally:
                sys.path.remove(self.current_arch_path)

        arch_specialized_ops = []
        for mod in modules:
            arch_specialized_ops.extend(self.get_functions_from_module(mod))
        return arch_specialized_ops

134 

135 

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor-specific `ops` and `fused` modules, at most once.

    Populates the module-level `ops_module` / `fused_module` globals.
    A missing submodule is expected (the vendor may not specialize it)
    and only prints a note; any other import failure is re-raised as a
    RuntimeError.
    """
    global vendor_extra_lib_imported
    if vendor_extra_lib_imported:
        return
    global ops_module, fused_module

    def _load(submodule, label):
        # One attempt per submodule; returns None when the vendor ships
        # no specialization.
        try:
            return importlib.import_module(f"_{vendor_name}.{submodule}")
        except ModuleNotFoundError:
            # Note: the original message concatenated to "...found inthe..."
            # (missing space between the f-string fragments) — fixed here.
            print(
                f"[Note] No specialized {label} operators were found in "
                f"the {vendor_name} implementation, and general {label} operators are used by default."
            )
            return None
        except Exception as e:
            raise RuntimeError(f"Import vendor extra lib failed: {e}") from e

    loaded_ops = _load("ops", "common")
    if loaded_ops is not None:
        ops_module = loaded_ops
    loaded_fused = _load("fused", "fused")
    if loaded_fused is not None:
        fused_module = loaded_fused
    vendor_extra_lib_imported = True

161 

162 

def get_codegen_result(code, result_key):
    """Execute generated Python source and return the binding named `result_key`.

    Parses `code` to an AST, compiles and executes it in this module's
    globals, then returns ``globals()[result_key]``.

    SECURITY NOTE: this exec's arbitrary code into module globals — it must
    only ever receive trusted, internally generated snippets, never
    external input.

    Raises whatever the executed code raises (the previous
    ``except Exception as e: raise e`` wrapper was a no-op that only
    obscured the traceback), plus KeyError if `result_key` is not bound.
    """
    parsed_ast = ast.parse(code)
    compiled_code = compile(parsed_ast, filename="<ast>", mode="exec")
    exec(compiled_code, globals())  # trusted, internally generated code only
    return globals()[result_key]

171 

172 

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` via generated code and cache the result.

    `tensor` and `attr_name` are source-text fragments spliced into a small
    generated snippet; results are memoized per (tensor, attr_name) pair.
    Also lazily fills the module-level `device_name` as a side effect.
    """
    global device_name
    if not device_name:
        device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")

182 

183 

def set_tl_extra_backend_module(vendor_name=None):
    """Import and cache the vendor's triton libdevice module.

    Stores ``triton.language.extra.<name>.libdevice`` in the module-level
    `tl_extra_backend_module`; also lazily fills `device_name`.
    """
    global device_name, tl_extra_backend_module
    info = get_vendor_info(vendor_name)
    if not device_name:
        device_name = info.device_name
    # Some vendors publish libdevice under a dedicated triton extra name.
    extra_name = info.triton_extra_name or device_name
    tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )

191 

192 

def get_tl_extra_backend_module():
    """Return the libdevice module cached by set_tl_extra_backend_module (None until it runs)."""
    return tl_extra_backend_module

195 

196 

def set_torch_backend_device_fn(vendor_name=None):
    """Cache the `torch.backends.<device_name>` module in `torch_device_fn_device`.

    Some devices ship no torch.backends submodule; for those the global is
    set to None instead. Also lazily fills the module-level `device_name`.
    """
    global device_name, torch_device_fn_device
    if not device_name:
        device_name = get_vendor_info(vendor_name).device_name
    # These device names have no torch.backends.<name> module to import.
    if device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"):
        torch_device_fn_device = None
    else:
        torch_device_fn_device = importlib.import_module(
            f"torch.backends.{device_name}"
        )

205 

206 

def get_torch_backend_device_fn():
    """Return the torch.backends module cached by set_torch_backend_device_fn (may be None)."""
    return torch_device_fn_device

209 

210 

def gen_torch_device_object(vendor_name=None):
    """Return (and memoize) the ``torch.<device_name>`` namespace object."""
    global device_name, torch_device_object
    # Fast path: already generated on a previous call.
    if torch_device_object is not None:
        return torch_device_object
    if not device_name:
        device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{device_name}
"""
    torch_device_object = get_codegen_result(code, "fn")
    return torch_device_object

222 

223 

def get_vendor_module(vendor_name, query=False):
    """Import a backend vendor module by name.

    With ``query=True`` the module is imported and handed straight back to
    the caller (no caching). Otherwise ``_<vendor_name>`` is imported once
    and cached in the module-level `vendor_module` global.
    """

    def _load(name):
        # This package's directory must be importable; guard the append so
        # repeated calls don't grow sys.path without bound (the original
        # appended a duplicate entry on every call).
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        if backend_dir not in sys.path:
            sys.path.append(backend_dir)
        return importlib.import_module(name)

    if query:
        # The purpose of a query is to provide the user with the instance
        # that he wants to import.
        return _load(vendor_name)

    global vendor_module
    if vendor_module is None:
        vendor_module = _load("_" + vendor_name)
    return vendor_module

240 

241 

def get_vendor_info(vendor_name=None, query=False):
    """Return a vendor module's `vendor_info` attribute.

    Query mode fetches the info from a freshly resolved module without
    touching the cache; otherwise the cached module-level `vendor_module`
    is (lazily) populated and used.
    """
    global vendor_module  # noqa: F824
    if query:
        module = get_vendor_module(vendor_name, query)
    else:
        get_vendor_module(vendor_name)  # populates the vendor_module global
        module = vendor_module
    return module.vendor_info

248 

249 

def get_vendor_infos():
    """Collect `vendor_info` for every registered vendor.

    Best-effort: vendors whose module fails to load are silently skipped.
    """
    infos = []
    for name in vendors.get_all_vendors():
        try:
            infos.append(get_vendor_info("_" + name, query=True))
        except Exception:
            # Deliberate: an unavailable vendor just doesn't contribute.
            pass
    return infos

261 

262 

def get_current_device_extend_op(vendor_name=None):
    """Return (and memoize) the vendor's customized ops as [(name, fn), ...].

    Gathers every function from the vendor's optional `ops` and `fused`
    modules (loaded on demand via import_vendor_extra_lib).
    """
    global customized_ops
    import_vendor_extra_lib(vendor_name)
    if customized_ops is not None:
        return customized_ops
    collected = []
    for module in (ops_module, fused_module):
        if module is not None:
            collected.extend(inspect.getmembers(module, inspect.isfunction))
    customized_ops = collected
    return customized_ops

276 

277 

def get_curent_device_unused_op(vendor_name=None):
    """Return the vendor's CUSTOMIZED_UNUSED_OPS as a list.

    NOTE(review): the "curent" typo in the name is part of the public API
    and is kept for backward compatibility.
    """
    global vendor_module  # noqa: F824
    get_vendor_module(vendor_name)  # ensure the vendor module is loaded
    unused_ops = vendor_module.CUSTOMIZED_UNUSED_OPS
    return list(unused_ops)

282 

283 

def get_heuristic_config(vendor_name=None):
    """Load a vendor's HEURISTICS_CONFIGS, falling back to nvidia's.

    Tries ``_<vendor_name>.heuristics_config_utils`` first; any failure
    falls back to ``_nvidia.heuristics_config_utils``. Returns None when
    the loaded module defines no HEURISTICS_CONFIGS.
    """
    global heuristic_config_module
    try:
        heuristic_config_module = importlib.import_module(
            f"_{vendor_name}.heuristics_config_utils"
        )
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; any vendor-load problem still falls
        # back to the nvidia defaults.
        heuristic_config_module = importlib.import_module(
            "_nvidia.heuristics_config_utils"
        )
    if hasattr(heuristic_config_module, "HEURISTICS_CONFIGS"):
        return heuristic_config_module.HEURISTICS_CONFIGS
    return None

297 

298 

def get_tune_config(vendor_name=None):
    """Ensure the vendor module is loaded, then fetch its tune config."""
    global vendor_module  # noqa: F824
    get_vendor_module(vendor_name)  # side effect: populates vendor_module
    config = backend_utils.get_tune_config(vendor_name)
    return config

303 

304 

# Explicit public API. The previous value, ["*"], is not a valid __all__
# entry: `from ... import *` looks each listed name up as an attribute,
# so "*" would raise AttributeError.
__all__ = [
    "BackendArchEvent",
    "import_vendor_extra_lib",
    "get_codegen_result",
    "gen_torch_tensor_attr_res",
    "set_tl_extra_backend_module",
    "get_tl_extra_backend_module",
    "set_torch_backend_device_fn",
    "get_torch_backend_device_fn",
    "gen_torch_device_object",
    "get_vendor_module",
    "get_vendor_info",
    "get_vendor_infos",
    "get_current_device_extend_op",
    "get_curent_device_unused_op",
    "get_heuristic_config",
    "get_tune_config",
]