Coverage for src/flag_gems/runtime/configloader.py: 88%
123 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import copy
2import warnings
4import triton
6from . import backend
7from .backend.device import DeviceDetector
class ConfigLoader:
    """Singleton that resolves autotune and heuristics configs per operator.

    Lookups are layered. An architecture-specialized config (taken from
    ``backend.BackendArchEvent`` when it reports ``has_arch``) wins over the
    config of the detected vendor, which in turn wins over the "nvidia"
    config that serves as the default.
    """

    _instance = None

    def __new__(cls, *args, **kargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Load every config layer and pre-build the vendor's tuned configs.

        ``__init__`` runs on every ``ConfigLoader()`` call because ``__new__``
        always hands back the same object, so the ``initialized`` marker is
        used to make the setup happen only once.
        """
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        try:
            arch_event = backend.BackendArchEvent()
            if arch_event.has_arch:
                self.arch_specialized_yaml_config = arch_event.autotune_configs
                self.arch_heuristics_config = arch_event.heuristics_configs
        except Exception as err:
            # Best effort: a failing arch probe is reported and the loader
            # carries on with the vendor/default layers only.
            print(f"[INFO] : {err}")

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config is wrapped in triton.Config according to
        # primitive_yaml_config
        self.loaded_triton_config = {}
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name in ["hygon"]:
            # hygon configs carry one extra launch parameter.
            self.triton_config_default["num_ldmatrixes"] = 0
        self.load_all()

    def load_all(self):
        """Build and cache the tuned configs of every op in the vendor config."""
        for key in self.vendor_primitive_yaml_config:
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the heuristics config of the detected vendor."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the "nvidia" heuristics config, used as the default layer."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the "nvidia" tune config, used as the default layer."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the tune config of the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics entry for ``op_name``, or ``None``.

        Layers are searched in the order arch, vendor, default. A layer that
        is ``None`` or empty is skipped; the vendor layer in particular may be
        ``None`` (``__init__`` only warns about it), and it must not prevent
        the default layer from being consulted. A warning is issued when no
        layer knows the op.
        """
        layers = (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        )
        for layer in layers:
            if layer and op_name in layer:
                return layer[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Return the values one generation step iterates over.

        A list or tuple is returned unchanged, an int is wrapped in a
        one-element list, and anything else is used as a key into
        ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into one ``triton.Config`` per combination.

        The plan is walked depth-first with an explicit stack. Each plan entry
        contributes one value per branch, and a branch that has consumed the
        whole plan is emitted as a config. Because pending states are taken
        from the end of the stack, the last value of a step is expanded first.
        """
        # NOTE(review): unlike get_tuned_config, this path never forwards
        # num_ldmatrixes even when cur_config carries it - confirm whether
        # that is intended for hygon.
        all_configs = []
        final_step = len(iteration_plan)
        stack = [(std_config, 0)]

        while stack:
            cur_config, current_step = stack.pop()

            if current_step == final_step:
                all_configs.append(
                    triton.Config(
                        cur_config["META"],
                        num_warps=cur_config["num_warps"],
                        num_stages=cur_config["num_stages"],
                        num_ctas=cur_config["num_ctas"],
                    )
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            key_config = self._resolve_iteration_values(
                gen_config, cur_entry["source"]
            )
            for single_value in key_config:
                # Each branch gets its own copy so siblings never share state.
                new_config = copy.deepcopy(cur_config)
                if cur_entry["kind"] == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif cur_entry["kind"] == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append((new_config, current_step + 1))
        return all_configs

    def to_gen_config(self, gen_config):
        """Generate the configs described by ``gen_config["param_map"]``.

        When ``param_map["META"]`` is a dict, each of its items becomes its
        own iteration step; otherwise the META source is iterated as a whole
        and each value replaces the META block. Every other ``param_map``
        item becomes a step that sets a top-level config field.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_tuned_config(self, op_name):
        """Return the list of ``triton.Config`` objects for ``op_name``.

        An entry already present in ``loaded_triton_config`` is returned
        directly. Otherwise the raw config list is taken from the arch,
        vendor or default layer (in that order); an op unknown to all three
        surfaces whatever error the default layer's lookup raises. Entries
        containing ``gen_key`` are expanded through ``to_gen_config``, and the
        remaining ones are wrapped one-to-one, with
        ``triton_config_default`` supplying any missing launch parameter.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        if (
            self.arch_specialized_yaml_config
            and op_name in self.arch_specialized_yaml_config
        ):
            current_op_configs = self.arch_specialized_yaml_config[op_name]
        elif op_name in self.vendor_primitive_yaml_config:
            current_op_configs = self.vendor_primitive_yaml_config[op_name]
        else:
            current_op_configs = self.default_primitive_yaml_config[op_name]

        configs = []
        # An empty (or missing) list yields no configs instead of failing.
        if not current_op_configs:
            return configs

        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            # Start from the defaults and let the entry override them.
            current_config = {
                name: single_config[name] if name in single_config else default
                for name, default in self.triton_config_default.items()
            }
            launch_kwargs = {
                "num_warps": current_config["num_warps"],
                "num_stages": current_config["num_stages"],
                "num_ctas": current_config["num_ctas"],
            }
            if self.device.vendor_name in ["hygon"]:
                launch_kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
            configs.append(triton.Config(single_config["META"], **launch_kwargs))
        return configs