Coverage for src/flag_gems/runtime/backend/backend_utils.py: 59%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
import functools
import os
from dataclasses import dataclass
from typing import Optional

import yaml
# Metadata template: each vendor needs to specialize an instance of this template.
@dataclass
class VendorInfoBase:
    """Per-vendor metadata record.

    The first three fields are required; the last two default to ``None``.

    Attributes:
        vendor_name: Name of the vendor.
        device_name: Name of the vendor's device.
        device_query_cmd: Device query command (presumably a command line
            used to probe for the device -- confirm against the callers).
        dispatch_key: Optional dispatch key; ``None`` when not specified.
        triton_extra_name: Optional extra Triton name; ``None`` when not
            specified.
    """

    vendor_name: str
    device_name: str
    device_query_cmd: str
    # Annotated as Optional because the default really is None.
    dispatch_key: Optional[str] = None
    triton_extra_name: Optional[str] = None
def get_tune_config(vendor_name=None, file_mode="r", file_path=None):
    """Load and parse a ``tune_configs.yaml`` file.

    The file is looked up in one of two places:

    * ``file_path`` is falsy: ``<directory of this module>/_<vendor_name>/
      tune_configs.yaml``.
    * otherwise: ``<file_path>/tune_configs.yaml`` (``file_path`` is treated
      as a directory).

    Args:
        vendor_name: Vendor whose bundled configuration should be loaded.
            Required whenever ``file_path`` is falsy.
        file_mode: Mode passed to ``open``.
        file_path: Directory containing ``tune_configs.yaml``.

    Returns:
        The object produced by ``yaml.safe_load``. ``None`` is returned when
        ``file_path`` was supplied (not ``None``) and the file does not exist.

    Raises:
        FileNotFoundError: The file is missing and ``file_path`` was ``None``.
        ValueError: The file is not valid YAML.
        RuntimeError: Any other failure (for example ``vendor_name`` left as
            ``None`` while ``file_path`` is falsy, or a permission error).
    """
    # A caller-supplied directory makes the file optional: a missing file is
    # silently reported as None instead of being treated as an error.
    missing_file_allowed = file_path is not None
    config = None
    try:
        if not file_path:
            base_dir = os.path.dirname(os.path.abspath(__file__))
            file_path = os.path.join(
                base_dir, "_" + vendor_name, "tune_configs.yaml"
            )
        else:
            file_path = os.path.join(file_path, "tune_configs.yaml")
        with open(file_path, file_mode) as file:
            config = yaml.safe_load(file)
    except FileNotFoundError as e:
        if not missing_file_allowed:
            # Chain explicitly so errno/filename of the original stay reachable.
            raise FileNotFoundError(
                f"Configuration file not found: {file_path}"
            ) from e
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse YAML file: {e}") from e
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e

    return config
@functools.lru_cache(maxsize=None)
def _load_expand_config(file_path, file_mode="r"):
    """Read a YAML file and return its parsed content, memoised per argument set.

    Results are cached by ``(file_path, file_mode)``, so repeated calls with
    the same arguments hand back the same object without re-reading the file.

    Args:
        file_path: Path of the YAML file to read.
        file_mode: Mode passed to ``open``.

    Returns:
        The value produced by ``yaml.safe_load``, or an empty dict when that
        value is falsy (for example an empty file).
    """
    with open(file_path, file_mode) as stream:
        parsed = yaml.safe_load(stream)
    if parsed:
        return parsed
    return {}
def get_expand_config(op_name=None, file_mode="r", file_path=None):
    """Return the expand configuration, optionally narrowed to one operator.

    The file is loaded through ``_load_expand_config``, which is memoised with
    ``functools.lru_cache``; the object returned here is therefore shared
    between calls and should be treated as read-only.

    Args:
        op_name: Key to look up in the loaded configuration. ``None`` returns
            the whole configuration.
        file_mode: Mode passed to ``open`` by the loader.
        file_path: Path of the YAML configuration file. Required.

    Returns:
        The whole configuration when ``op_name`` is ``None``; otherwise
        ``config.get(op_name)``, which is ``None`` when the key is absent.

    Raises:
        ValueError: ``file_path`` is falsy, or the file is not valid YAML.
        FileNotFoundError: The file does not exist.
        RuntimeError: Any other failure while loading the file.
    """
    if not file_path:
        raise ValueError("expand config file path is required")
    try:
        config = _load_expand_config(file_path, file_mode)
    except FileNotFoundError as e:
        # Chain explicitly so the original error stays attached as the cause.
        raise FileNotFoundError(
            f"Configuration file not found: {file_path}"
        ) from e
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse YAML file: {e}") from e
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e
    if op_name is None:
        return config
    return config.get(op_name)