Coverage for src/flag_gems/runtime/configloader.py: 88%

123 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import copy 

2import warnings 

3 

4import triton 

5 

6from . import backend 

7from .backend.device import DeviceDetector 

8 

9 

class ConfigLoader(object):
    """Singleton that resolves Triton autotune and heuristics configs per op.

    Lookups consult, in priority order, the architecture-specialized configs
    (only when ``backend.BackendArchEvent().has_arch`` is true), the detected
    vendor's configs, and finally the ``"nvidia"`` configs used as defaults.
    Autotune entries are wrapped into ``triton.Config`` objects; an entry that
    contains the ``gen`` key is expanded into every combination described by
    its ``param_map`` (see ``to_gen_config``).
    """

    _instance = None

    def __new__(cls, *args, **kargs):
        # Every construction returns the same object; ``__init__`` uses the
        # ``initialized`` attribute to avoid reloading configs on later calls.
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        try:
            arch_event = backend.BackendArchEvent()
            if arch_event.has_arch:
                self.arch_specialized_yaml_config = arch_event.autotune_configs
                self.arch_heuristics_config = arch_event.heuristics_configs
        except Exception as err:
            # Arch-specialized configs are optional: report the problem and
            # continue with the vendor/default configs.
            print(f"[INFO] : {err}")

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config maps op name -> list of triton.Config objects
        # built from primitive_yaml_config (filled by load_all)
        self.loaded_triton_config = {}
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name in ["hygon"]:
            # On hygon, triton.Config is additionally given num_ldmatrixes.
            self.triton_config_default["num_ldmatrixes"] = 0
        self.load_all()

    def load_all(self):
        """Preload configs for every op listed in the vendor tune config.

        Ops that only appear in the arch-specialized or default configs are
        not preloaded; ``get_tuned_config`` builds them on each call.
        """
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the heuristics config for the detected vendor (may be None)."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the fallback heuristics config (the "nvidia" one)."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the fallback autotune config (the "nvidia" one)."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the autotune config for the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics entry for ``op_name``, or None if absent.

        Sources are checked in priority order: arch-specialized, vendor,
        default. A source that is None (e.g. a vendor without a heuristics
        config, which ``__init__`` only warns about) is skipped instead of
        raising ``TypeError``. Emits a warning when no source has the op.
        """
        for source in (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        ):
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Turn a ``param_map`` source into the sequence of values to try.

        A list/tuple is used as-is, an int becomes a one-element list, and
        anything else is treated as a key into ``gen_config`` whose value is
        the sequence to iterate.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _build_triton_config(self, meta, params):
        """Wrap ``meta`` and launch parameters ``params`` in a triton.Config.

        ``params`` must contain ``num_warps``, ``num_stages`` and
        ``num_ctas`` (plus ``num_ldmatrixes`` on hygon, which is only
        forwarded for that vendor).
        """
        launch_kwargs = {
            "num_warps": params["num_warps"],
            "num_stages": params["num_stages"],
            "num_ctas": params["num_ctas"],
        }
        if self.device.vendor_name in ["hygon"]:
            launch_kwargs["num_ldmatrixes"] = params["num_ldmatrixes"]
        return triton.Config(meta, **launch_kwargs)

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into every combination of its values.

        Each plan step is a dict with ``key``, ``source`` and ``kind``:
        ``meta_field`` sets ``META[key]``, ``meta_block`` replaces the whole
        ``META`` dict, and any other kind sets a top-level launch parameter.
        The expansion is a depth-first walk with an explicit LIFO stack, so
        the last value of each source is expanded first.

        Returns:
            list of triton.Config, one per complete combination.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [{"cur_config": std_config, "current_step": 0}]

        while stack:
            cur_state = stack.pop()
            cur_config = cur_state["cur_config"]
            current_step = cur_state["current_step"]

            if current_step == final_step:
                all_configs.append(
                    self._build_triton_config(cur_config["META"], cur_config)
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            key_config = self._resolve_iteration_values(
                gen_config, cur_entry["source"]
            )
            for single_value in key_config:
                # Deep copy so sibling branches never share nested META dicts.
                new_config = copy.deepcopy(cur_config)
                if cur_entry["kind"] == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif cur_entry["kind"] == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append(
                    {
                        "cur_config": new_config,
                        "current_step": current_step + 1,
                    }
                )
        return all_configs

    def to_gen_config(self, gen_config):
        """Expand a generator-style YAML entry into triton.Config objects.

        ``gen_config["param_map"]["META"]`` is either a dict mapping each
        META key to a value source, or a single source whose values are whole
        META dicts. Every other ``param_map`` key maps a launch parameter to a
        source. Unspecified launch parameters keep ``triton_config_default``.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]
        iteration_plan = []

        if isinstance(meta_config, dict):
            for meta_key, source in meta_config.items():
                iteration_plan.append(
                    {"key": meta_key, "source": source, "kind": "meta_field"}
                )
        else:
            iteration_plan.append(
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            )

        for key, source in param_config.items():
            if key == "META":
                continue
            iteration_plan.append(
                {"key": key, "source": source, "kind": "config_field"}
            )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_tuned_config(self, op_name):
        """Return the list of triton.Config candidates for ``op_name``.

        Results preloaded by ``load_all`` are returned directly. Otherwise the
        YAML entries come from the first source containing ``op_name``:
        arch-specialized, vendor, then default. An empty or null entry yields
        an empty list.

        Raises:
            KeyError: if ``op_name`` is in none of the sources (assuming the
                default config is a dict).
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        if (
            self.arch_specialized_yaml_config
            and op_name in self.arch_specialized_yaml_config
        ):
            current_op_configs = self.arch_specialized_yaml_config[op_name]
        elif (
            self.vendor_primitive_yaml_config
            and op_name in self.vendor_primitive_yaml_config
        ):
            current_op_configs = self.vendor_primitive_yaml_config[op_name]
        else:
            current_op_configs = self.default_primitive_yaml_config[op_name]

        configs = []
        # A YAML key with no value loads as None; treat it like an empty list.
        for single_config in current_op_configs or ():
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            # Start from the vendor defaults and override only the launch
            # parameters this entry specifies.
            current_config = copy.deepcopy(self.triton_config_default)
            for default_param in current_config:
                if default_param in single_config:
                    current_config[default_param] = single_config[default_param]
            configs.append(
                self._build_triton_config(single_config["META"], current_config)
            )
        return configs