Coverage for src/flag_gems/config.py: 57%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import os
2import warnings
3from pathlib import Path
5import yaml
# Optional imports used inside helper functions to avoid hard dependencies at
# module import time.
try:  # pragma: no cover - best effort fallback
    from flag_gems import runtime as _runtime
except Exception:  # noqa: BLE001
    _runtime = None

# Import-time C-extension state. Everything starts "off" and is only switched
# on below once the corresponding import has actually succeeded.
has_c_extension = False
use_c_extension = False
aten_patch_list = []

# Exported for the cpp extension to find this package's source directory.
# NOTE(review): presumably read while the extension loads, so this assignment
# is kept ahead of the ``c_operators`` import.
os.environ["FLAGGEMS_SOURCE_DIR"] = str(Path(__file__).parent.resolve())

try:
    from flag_gems import c_operators
except ImportError:
    c_operators = None
else:
    has_c_extension = True

# The C extension is opt-in: only the exact value "1" enables it.
use_env_c_extension = os.environ.get("USE_C_EXTENSION", "0") == "1"

if use_env_c_extension and has_c_extension:
    # Requested and importable: pick up the registered aten patch ops, and
    # stay on the pure Python path if that registry is missing or incomplete.
    try:
        from flag_gems import aten_patch

        aten_patch_list = aten_patch.get_registered_ops()
        use_c_extension = True
    except (ImportError, AttributeError):
        aten_patch_list = []
        use_c_extension = False
elif use_env_c_extension:
    # Requested but the compiled module could not be imported.
    warnings.warn(
        "[FlagGems] USE_C_EXTENSION is set, but C extension is not available. "
        "Falling back to pure Python implementation.",
        RuntimeWarning,
    )
def load_enable_config_from_yaml(yaml_path, key="include"):
    """
    Load one operator list (``include`` or ``exclude``) from a YAML file.

    Expected YAML structure:
        include:   # operators to explicitly enable
          - op_a
          - op_b
        exclude:   # operators to skip
          - op_c

    Both keys are optional. A missing key, or a key with no entries
    (``include:``), yields an empty list. A single operator written as a
    scalar (``include: op_a``) is accepted as a one-element list.

    Args:
        yaml_path: Path (str or path-like) of the YAML file to read.
        key: Which list to return; must be ``"include"`` or ``"exclude"``.

    Returns:
        The operator names under ``key``, deduplicated while keeping the order
        of first appearance in the file. An empty list is returned (with a
        warning) when ``key`` is invalid, the file is missing or unreadable,
        the document is not a mapping, or the value under ``key`` is not a
        list of names.
    """
    # Validate the key up front: a bad key makes reading the file pointless.
    if key not in ("include", "exclude"):
        warnings.warn(
            f"load_enable_config_from_yaml: key must be 'include' or 'exclude', got: {key}"
        )
        return []

    yaml_path = Path(yaml_path)
    if not yaml_path.is_file():
        warnings.warn(f"load_enable_config_from_yaml: yaml not found: {yaml_path}")
        return []

    try:
        # Explicit encoding so parsing does not depend on the host locale.
        data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
    except Exception as err:  # best effort: any read/parse failure -> []
        warnings.warn(
            f"load_enable_config_from_yaml: unexpected error reading {yaml_path}: {err}"
        )
        return []

    # An empty document parses to None.
    if data is None:
        return []

    if not isinstance(data, dict):
        warnings.warn(
            f"load_enable_config_from_yaml: yaml {yaml_path} must be a mapping with 'include'/'exclude' lists"
        )
        return []

    operators = data.get(key)
    if operators is None:
        # Key absent, or present with no entries (YAML null).
        return []
    if isinstance(operators, str):
        operators = [operators]
    elif not isinstance(operators, (list, tuple, set)):
        warnings.warn(
            f"load_enable_config_from_yaml: '{key}' in {yaml_path} must be a list "
            f"of operator names, got {type(operators).__name__}"
        )
        return []

    # dict.fromkeys dedups while preserving file order, unlike list(set(...)),
    # whose order for strings varies between interpreter runs.
    return list(dict.fromkeys(operators))
def get_default_enable_config(vendor_name=None, arch_name=None):
    """
    Build the ordered list of default ``enable_configs.yaml`` candidates.

    Candidates live under ``<package>/runtime/backend``. If the vendor
    directory ``_<vendor_name>`` exists (or the backend root itself, when
    ``vendor_name`` is falsy), its arch-specific config (when ``arch_name`` is
    given) and its vendor-level config are listed first. The NVIDIA config is
    always appended last as the global fallback, but only once, so asking for
    ``vendor_name="nvidia"`` does not list the same file twice.

    Args:
        vendor_name: Backend vendor name, or None.
        arch_name: Architecture sub-directory name, or None.

    Returns:
        A list of ``pathlib.Path`` objects, most specific first. The paths are
        not checked for existence.
    """
    base_dir = Path(__file__).resolve().parent / "runtime" / "backend"
    vendor_dir = base_dir / f"_{vendor_name}" if vendor_name else base_dir

    candidates = []
    if vendor_dir.is_dir():
        if arch_name:
            candidates.append(vendor_dir / arch_name / "enable_configs.yaml")
        candidates.append(vendor_dir / "enable_configs.yaml")

    # Use nvidia as the default, but avoid a duplicate entry when it is the
    # requested vendor already (which would load and warn twice).
    fallback = base_dir / "_nvidia" / "enable_configs.yaml"
    if fallback not in candidates:
        candidates.append(fallback)
    return candidates
def resolve_user_setting(user_setting_info, user_setting_type="include"):
    """
    Resolve user setting for include/exclude operator lists.

    Args:
        user_setting_info: One of:
            - a list/tuple/set of operator names, used directly
              (deduplicated, first-seen order kept);
            - ``"default"``, or ``None`` when ``user_setting_type`` is
              ``"include"``: load the default YAML config for the current
              vendor/architecture;
            - any other string: treated as a path to a YAML file;
            - ``None`` for ``"exclude"``: nothing configured, returns ``[]``.
        user_setting_type: Either "include" or "exclude".

    Returns:
        List of operators based on the user setting; an empty list when
        nothing could be resolved.
    """
    if isinstance(user_setting_info, (list, tuple, set)):
        # Deduplicate deterministically (list(set(...)) order varies per run).
        return list(dict.fromkeys(user_setting_info))

    if user_setting_info == "default" or (
        user_setting_type == "include" and user_setting_info is None
    ):
        yaml_candidates = _default_yaml_candidates()
    elif isinstance(user_setting_info, str):
        yaml_candidates = [user_setting_info]
    elif user_setting_info is None:
        # No exclude list configured: nothing to resolve and nothing to warn.
        return []
    else:
        yaml_candidates = []

    # First candidate that yields a non-empty operator list wins.
    for yaml_path in yaml_candidates:
        operator_list = load_enable_config_from_yaml(yaml_path, user_setting_type)
        if operator_list:
            return operator_list
        # The loader already warns when the file is missing; here the file may
        # also exist but be empty or malformed.
        warnings.warn(
            f"resolve_user_setting: no {user_setting_type} ops loaded from yaml: {yaml_path}"
        )

    warnings.warn(
        f"resolve_user_setting: no {user_setting_type} ops found; returning empty list"
    )
    return []


def _default_yaml_candidates():
    """Return default enable-config YAML candidates for the current device."""
    vendor_name = None
    # Must be bound even when the backend reports no architecture.
    arch_name = None
    # ``_runtime`` is None when the optional runtime import at module load
    # failed; fall back to the vendor-agnostic candidate list in that case.
    if _runtime is not None:
        vendor_name = _runtime.device.vendor_name
        arch_event = _runtime.backend.BackendArchEvent()
        if arch_event.has_arch:
            arch_name = getattr(arch_event, "arch", None)
    return get_default_enable_config(vendor_name, arch_name)
# Public names re-exported by ``from flag_gems.config import *``: the
# import-time C-extension state plus the include/exclude resolver.
__all__ = """
    aten_patch_list
    has_c_extension
    use_c_extension
    resolve_user_setting
""".split()