Coverage for src/flag_gems/config.py: 57%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import os 

2import warnings 

3from pathlib import Path 

4 

5import yaml 

6 

# Optional import used inside helper functions to avoid a hard dependency at
# module import time. If the import raises any Exception, `_runtime` is None
# and helpers that need it must cope with that.
try:  # pragma: no cover - best effort fallback
    from flag_gems import runtime as _runtime
except Exception:  # noqa: BLE001
    _runtime = None

# Public capability flags. They keep these defaults unless the checks below
# succeed.
has_c_extension = False
use_c_extension = False
aten_patch_list = []

# set FLAGGEMS_SOURCE_DIR for cpp extension to find
os.environ["FLAGGEMS_SOURCE_DIR"] = str(Path(__file__).parent.resolve())

try:
    from flag_gems import c_operators
except ImportError:
    # `has_c_extension` is already False; only the module handle needs a value.
    c_operators = None
else:
    has_c_extension = True

# The C extension is opt-in: it is requested only when the environment
# variable USE_C_EXTENSION is exactly "1".
use_env_c_extension = os.environ.get("USE_C_EXTENSION", "0") == "1"
if use_env_c_extension and not has_c_extension:
    warnings.warn(
        "[FlagGems] USE_C_EXTENSION is set, but C extension is not available. "
        "Falling back to pure Python implementation.",
        RuntimeWarning,
    )

if has_c_extension and use_env_c_extension:
    try:
        from flag_gems import aten_patch

        aten_patch_list = aten_patch.get_registered_ops()
        use_c_extension = True
    except (ImportError, AttributeError) as err:
        aten_patch_list = []
        use_c_extension = False
        # The user explicitly asked for the C extension, so do not disable it
        # silently: report why the fallback happened.
        warnings.warn(
            "[FlagGems] USE_C_EXTENSION is set, but the aten patch could not be "
            f"loaded ({err!r}). Falling back to pure Python implementation.",
            RuntimeWarning,
        )

47 

48 

def load_enable_config_from_yaml(yaml_path, key="include"):
    """
    Load one operator list (``include`` or ``exclude``) from a YAML file.

    Expected YAML structure:
        include:        # operators to explicitly enable
          - op_a
          - op_b
        exclude:        # operators to skip
          - op_c

    Both keys are optional; a missing or empty key yields an empty list.

    Args:
        yaml_path: Path of the YAML file to read (anything ``Path`` accepts).
        key: Which list to return; must be "include" or "exclude".

    Returns:
        The operators listed under ``key``, de-duplicated with first-seen
        order preserved. An empty list is returned, after a warning, when the
        file is missing, cannot be read or parsed, ``key`` is invalid, the
        document is not a mapping, or the value under ``key`` is not a list of
        hashable entries. An empty document returns an empty list silently.
    """
    yaml_path = Path(yaml_path)
    if not yaml_path.is_file():
        warnings.warn(f"load_enable_config_from_yaml: yaml not found: {yaml_path}")
        return []

    # Deliberately broad: any read or parse failure degrades to "no operators".
    try:
        data = yaml.safe_load(yaml_path.read_text())
    except Exception as err:
        warnings.warn(
            f"load_enable_config_from_yaml: unexpected error reading {yaml_path}: {err}"
        )
        return []

    if key not in ("include", "exclude"):
        warnings.warn(
            f"load_enable_config_from_yaml: key must be 'include' or 'exclude', got: {key}"
        )
        return []

    # An empty YAML document parses to None.
    if data is None:
        return []

    if not isinstance(data, dict):
        warnings.warn(
            f"load_enable_config_from_yaml: yaml {yaml_path} must be a mapping with 'include'/'exclude' lists"
        )
        return []

    operators = data.get(key)
    # `include:` with no value parses to None; treat it like a missing key.
    if operators is None:
        return []

    # Reject scalars: iterating a bare string would split it into characters.
    if not isinstance(operators, (list, tuple, set)):
        warnings.warn(
            f"load_enable_config_from_yaml: '{key}' in {yaml_path} must be a list of operators"
        )
        return []

    try:
        # dict.fromkeys de-duplicates while keeping the order from the file,
        # so the result is stable from run to run.
        return list(dict.fromkeys(operators))
    except TypeError:
        # Raised for unhashable entries such as nested lists or mappings.
        warnings.warn(
            f"load_enable_config_from_yaml: '{key}' in {yaml_path} must contain only scalar operator names"
        )
        return []

93 

94 

def get_default_enable_config(vendor_name=None, arch_name=None):
    """
    Build the ordered list of default ``enable_configs.yaml`` candidates.

    Paths are rooted at the ``runtime/backend`` directory next to this module
    and run from most to least specific:

      1. ``<vendor dir>/<arch_name>/enable_configs.yaml`` (only if ``arch_name``
         is given)
      2. ``<vendor dir>/enable_configs.yaml``
      3. ``_nvidia/enable_configs.yaml`` (always included, as the fallback)

    ``<vendor dir>`` is ``_<vendor_name>``, or the backend directory itself
    when ``vendor_name`` is empty. Entries 1 and 2 are added only when that
    directory exists. The files themselves are not checked for existence.

    Args:
        vendor_name: Backend vendor name without the leading underscore.
        arch_name: Optional architecture sub-directory name.

    Returns:
        A list of ``Path`` objects with no duplicates.
    """
    base_dir = Path(__file__).resolve().parent / "runtime" / "backend"
    vendor_dir = base_dir / f"_{vendor_name}" if vendor_name else base_dir

    candidates = []
    if vendor_dir.is_dir():
        if arch_name:
            candidates.append(vendor_dir / arch_name / "enable_configs.yaml")
        candidates.append(vendor_dir / "enable_configs.yaml")

    # use nvidia as default
    fallback = base_dir / "_nvidia" / "enable_configs.yaml"
    # When the vendor is nvidia the fallback is already listed; adding it again
    # would make callers load the same file twice.
    if fallback not in candidates:
        candidates.append(fallback)
    return candidates

108 

109 

def resolve_user_setting(user_setting_info, user_setting_type="include"):
    """
    Resolve user setting for include/exclude operator lists.

    Args:
        user_setting_info: One of
            - a list/tuple/set of operators, used directly;
            - "default", or None when ``user_setting_type`` is "include", to
              load the default YAML configs for the current vendor/arch;
            - a string or ``os.PathLike`` path to a YAML file.
        user_setting_type: Either "include" or "exclude".

    Returns:
        List of operators based on the user setting. Explicit collections are
        de-duplicated with first-seen order preserved. An empty list is
        returned, after a warning, when no candidate YAML yields any operator.
    """
    # An explicit collection is used as-is. dict.fromkeys de-duplicates while
    # keeping the caller's order, so the result is deterministic.
    if isinstance(user_setting_info, (list, tuple, set)):
        return list(dict.fromkeys(user_setting_info))

    yaml_candidates = []
    # If set to "default" or None (for include type),
    # load from default YAML config files based on vendor and architecture
    if user_setting_info == "default" or (
        user_setting_type == "include" and user_setting_info is None
    ):
        vendor_name = None
        # Initialised up front so it is bound even when no arch is reported.
        arch_name = None
        if _runtime is None:
            # The optional runtime import failed, so vendor and arch cannot be
            # inferred; fall back to the vendor-independent candidates.
            warnings.warn(
                "resolve_user_setting: flag_gems.runtime is unavailable; "
                "using vendor-independent default configs"
            )
        else:
            vendor_name = _runtime.device.vendor_name
            arch_event = _runtime.backend.BackendArchEvent()
            if arch_event.has_arch:
                arch_name = getattr(arch_event, "arch", None)
        yaml_candidates = get_default_enable_config(vendor_name, arch_name)

    # A string or path-like value is treated as a YAML file path
    elif isinstance(user_setting_info, (str, os.PathLike)):
        yaml_candidates.append(user_setting_info)

    # Iterate through candidate YAML paths; the first non-empty list wins
    for yaml_path in yaml_candidates:
        operator_list = load_enable_config_from_yaml(yaml_path, user_setting_type)
        if operator_list:
            return operator_list
        # Also reached when the file exists but lists no operators for this
        # type, not only when the file is missing.
        warnings.warn(
            f"resolve_user_setting: {user_setting_type} yaml not found: {yaml_path}"
        )

    # If no operators found in any YAML, warn and return empty list
    warnings.warn(
        f"resolve_user_setting: no {user_setting_type} ops found; returning empty list"
    )
    return []

157 

158 

# Names exported by `from ... import *`. The YAML helper functions defined in
# this module are not listed here.
__all__ = [
    "aten_patch_list",
    "has_c_extension",
    "use_c_extension",
    "resolve_user_setting",
]
164]