Coverage for src/flag_gems/runtime/configloader.py: 62%
201 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import copy
2import warnings
4import triton
6from . import backend, common
7from .backend.device import DeviceDetector
10class ConfigLoader(object):
11 _instance = None
13 def __new__(cls, *args, **kargs):
14 if cls._instance is None:
15 cls._instance = super(ConfigLoader, cls).__new__(cls)
16 return cls._instance
18 def __init__(self):
19 if not hasattr(self, "initialized"):
20 self.initialized = True
21 self.device = DeviceDetector()
22 # primitive_yaml_config is simply the dictionary returned by yaml
23 # and is reserved from being an attr for vendor customizability
24 self.arch_specialized_yaml_config = None
25 self.arch_heuristics_config = None
26 self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
27 self.default_primitive_yaml_config = self.get_default_tune_config()
28 self.vendor_heuristics_config = self.get_vendor_heuristics_config()
29 self.default_heuristics_config = self.get_default_heuristics_config()
30 self.update_config_from_arch()
32 if self.vendor_heuristics_config is None:
33 vendorname = self.device.vendor_name
34 warnings.warn(
35 f"The {vendorname} configuration of heuristics_config is None"
36 )
37 # gen_key is an identifier that indicates whether the current config needs to be generated automatically
38 self.gen_key = "gen"
39 # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config
40 self.loaded_triton_config = {}
41 self.triton_config_default = {
42 "num_stages": 2,
43 "num_warps": 4,
44 "num_ctas": 1,
45 }
46 if self.device.vendor_name == "hygon":
47 self.triton_config_default["num_ldmatrixes"] = 0
48 self.expand_config_registry = self._build_expand_registry()
49 self.load_all()
def update_config_from_arch(self):
    """Adopt arch-specific tune and heuristics configs when the backend has them.

    Best-effort: any exception raised while querying the backend is
    reported on stdout and otherwise ignored, leaving both attributes
    untouched.
    """
    try:
        arch_event = backend.BackendArchEvent()
        if not arch_event.has_arch:
            return
        self.arch_specialized_yaml_config = arch_event.autotune_configs
        self.arch_heuristics_config = arch_event.heuristics_configs
    except Exception as err:
        print(f"[INFO] : {err}")
60 def _get_op_configs(self, op_name):
61 """Get config for op_name from available config sources."""
62 for config in (
63 self.arch_specialized_yaml_config,
64 self.vendor_primitive_yaml_config,
65 self.default_primitive_yaml_config,
66 ):
67 if config and op_name in config:
68 return config[op_name]
69 return []
def _create_triton_config(self, single_config, current_config):
    """Build one ``triton.Config`` from a YAML entry and resolved launch values.

    ``single_config["META"]`` supplies the first argument of
    ``triton.Config``; ``num_warps``, ``num_stages`` and ``num_ctas`` are
    read from ``current_config``. On the ``hygon`` vendor,
    ``num_ldmatrixes`` is forwarded as well.
    """
    launch_kwargs = {
        name: current_config[name]
        for name in ("num_warps", "num_stages", "num_ctas")
    }
    if self.device.vendor_name == "hygon":
        launch_kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
    return triton.Config(single_config["META"], **launch_kwargs)
82 def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
83 if op_name == "bmm":
84 return [
85 triton.Config(
86 {
87 "TILE_M": block_m,
88 "TILE_N": block_n,
89 "TILE_K": block_k,
90 "GROUP_M": 1 if block_m == 32 else 2,
91 },
92 num_stages=s,
93 num_warps=w,
94 pre_hook=pre_hook,
95 )
96 for block_m in ranges["BLOCK_M"]
97 for block_n in ranges["BLOCK_N"]
98 for block_k in ranges["BLOCK_K"]
99 for s in ranges["s"]
100 for w in ranges["w"]
101 ]
103 if op_name == "addmm":
104 return [
105 triton.Config(
106 {
107 "BLOCK_SIZE_M": block_m,
108 "BLOCK_SIZE_N": block_n,
109 "BLOCK_SIZE_K": block_k,
110 },
111 num_stages=s,
112 num_warps=w,
113 pre_hook=pre_hook,
114 )
115 for block_m in ranges["BLOCK_M"]
116 for block_n in ranges["BLOCK_N"]
117 for block_k in ranges["BLOCK_K"]
118 for s in ranges["s"]
119 for w in ranges["w"]
120 ]
122 if op_name == "baddbmm":
123 return [
124 triton.Config(
125 {
126 "TILE_M": block_m,
127 "TILE_N": block_n,
128 "TILE_K": block_k,
129 "GROUP_M": 1 if block_m <= 32 else 2,
130 },
131 num_stages=s,
132 num_warps=w,
133 pre_hook=pre_hook,
134 )
135 for block_m in ranges["BLOCK_M"]
136 for block_n in ranges["BLOCK_N"]
137 for block_k in ranges["BLOCK_K"]
138 for s in ranges["s"]
139 for w in ranges["w"]
140 ]
142 if op_name == "mv":
143 return [
144 triton.Config(
145 {
146 "BLOCK_N": block_n,
147 "BLOCK_M": block_m,
148 },
149 num_stages=s,
150 num_warps=w,
151 pre_hook=pre_hook,
152 )
153 for block_n in ranges["BLOCK_N"]
154 for block_m in ranges["BLOCK_M"]
155 for s in ranges["s"]
156 for w in ranges["w"]
157 ]
159 if op_name == "mm_general_tma":
160 return [
161 triton.Config(
162 {
163 "BLOCK_M": block_m,
164 "BLOCK_N": block_n,
165 "BLOCK_K": block_k,
166 },
167 num_stages=s,
168 num_warps=w,
169 pre_hook=pre_hook,
170 )
171 for block_m in ranges["BLOCK_M"]
172 for block_n in ranges["BLOCK_N"]
173 for block_k in ranges["BLOCK_K"]
174 for s in ranges["s"]
175 for w in ranges["w"]
176 ]
178 if op_name in ("mm", "mm_sqmma"):
179 return [
180 triton.Config(
181 {
182 "BLOCK_M": block_m,
183 "BLOCK_N": block_n,
184 "BLOCK_K": block_k,
185 },
186 num_stages=s,
187 num_warps=w,
188 pre_hook=pre_hook,
189 )
190 for block_m in ranges["BLOCK_M"]
191 for block_n in ranges["BLOCK_N"]
192 for block_k in ranges["BLOCK_K"]
193 for s in ranges["s"]
194 for w in ranges["w"]
195 ]
197 if op_name in ("bmm_sqmma", "addmm_sqmma"):
198 return [
199 triton.Config(
200 {
201 "BLOCK_SIZE_M": block_m,
202 "BLOCK_SIZE_N": block_n,
203 "BLOCK_SIZE_K": block_k,
204 },
205 num_stages=s,
206 num_warps=w,
207 pre_hook=pre_hook,
208 )
209 for block_m in ranges["BLOCK_M"]
210 for block_n in ranges["BLOCK_N"]
211 for block_k in ranges["BLOCK_K"]
212 for s in ranges["s"]
213 for w in ranges["w"]
214 ]
216 if op_name == "gemv":
217 return [
218 triton.Config(
219 {"BLOCK_M": block_m, "BLOCK_K": block_k},
220 num_stages=s,
221 num_warps=w,
222 pre_hook=pre_hook,
223 )
224 for block_m in ranges["BLOCK_M"]
225 for block_k in ranges["BLOCK_K"]
226 for s in ranges["s"]
227 for w in ranges["w"]
228 ]
230 if op_name == "sparse_attention":
231 return [
232 triton.Config(
233 {"BLOCK": block},
234 num_stages=s,
235 num_warps=w,
236 pre_hook=pre_hook,
237 )
238 for block in ranges["BLOCK"]
239 for s in ranges["s"]
240 for w in ranges["w"]
241 ]
243 if op_name == "w8a8_block_fp8_general":
244 return [
245 triton.Config(
246 {
247 "BLOCK_M": block_m,
248 "BLOCK_N": block_n,
249 "BLOCK_K": block_k,
250 "GROUP_M": group_m,
251 },
252 num_stages=s,
253 num_warps=w,
254 pre_hook=pre_hook,
255 )
256 for block_m in ranges["BLOCK_M"]
257 for block_n in ranges["BLOCK_N"]
258 for block_k in ranges["BLOCK_K"]
259 for group_m in ranges["GROUP_M"]
260 for s in ranges["s"]
261 for w in ranges["w"]
262 ]
264 if op_name == "w8a8_block_fp8_general_tma":
265 group_m_values = ranges.get("GROUP_M", [None])
266 return [
267 triton.Config(
268 dict(
269 {
270 "BLOCK_M": block_m,
271 "BLOCK_N": block_n,
272 "BLOCK_K": block_k,
273 },
274 **({} if group_m is None else {"GROUP_M": group_m}),
275 ),
276 num_stages=s,
277 num_warps=w,
278 pre_hook=pre_hook,
279 )
280 for block_m in ranges["BLOCK_M"]
281 for block_n in ranges["BLOCK_N"]
282 for block_k in ranges["BLOCK_K"]
283 for group_m in group_m_values
284 for s in ranges["s"]
285 for w in ranges["w"]
286 ]
288 if op_name == "w8a8_block_fp8_general_splitk":
289 return [
290 triton.Config(
291 {
292 "BLOCK_M": block_m,
293 "BLOCK_N": block_n,
294 "BLOCK_K": block_k,
295 "SPLIT_K": split_k,
296 },
297 num_stages=s,
298 num_warps=w,
299 pre_hook=pre_hook,
300 )
301 for block_m in ranges["BLOCK_M"]
302 for block_n in ranges["BLOCK_N"]
303 for block_k in ranges["BLOCK_K"]
304 for split_k in ranges["SPLIT_K"]
305 for s in ranges["s"]
306 for w in ranges["w"]
307 ]
309 return []
def _build_single_expand_spec(
    self,
    op_name,
    expand_yaml_path=None,
    yaml_op_name=None,
):
    """Assemble the expand-registry record for one operator.

    The record holds the op's name inside the YAML file (``yaml_op_name``
    when given and truthy, otherwise ``op_name``), its key order and
    default strategy looked up in ``common``, and the YAML path passed in.

    Raises:
        KeyError: ``op_name`` has no entry in ``common.OP_KEY_ORDERS`` or
            ``common.DEFAULT_STRATEGIES``.
    """
    spec = {"yaml_op_name": yaml_op_name or op_name}
    spec["key"] = common.OP_KEY_ORDERS[op_name]
    spec["default_strategy"] = common.DEFAULT_STRATEGIES[op_name]
    spec["expand_yaml_path"] = expand_yaml_path
    return spec
def _build_expand_registry(self):
    """Build the op-name -> expand-spec table used by ``get_expand_config``.

    The first group of ops is pinned to ``common.DEFAULT_EXPAND_CONFIG_PATH``;
    the second group records no path of its own.
    """
    default_path = common.DEFAULT_EXPAND_CONFIG_PATH
    ops_with_default_yaml = ("bmm", "addmm", "baddbmm", "mv")
    ops_without_yaml = (
        "w8a8_block_fp8_general",
        "w8a8_block_fp8_general_splitk",
        "w8a8_block_fp8_general_tma",
        "mm_general_tma",
        "gemv",
        "sparse_attention",
        "mm",
        "bmm_sqmma",
        "addmm_sqmma",
    )

    registry = {}
    for name in ops_with_default_yaml:
        registry[name] = self._build_single_expand_spec(
            name, expand_yaml_path=default_path
        )
    for name in ops_without_yaml:
        registry[name] = self._build_single_expand_spec(name)
    return registry
def load_all(self):
    """Build and cache the tuned configs for every op in the vendor tune config.

    Each op name found in ``self.vendor_primitive_yaml_config`` is resolved
    through ``get_tuned_config`` and stored in ``self.loaded_triton_config``.
    A vendor config that is ``None`` or empty is treated as "nothing to
    preload" instead of raising ``TypeError``; ``_get_op_configs`` already
    tolerates such a config in the same way.
    """
    vendor_ops = self.vendor_primitive_yaml_config or ()
    for op_name in vendor_ops:
        self.loaded_triton_config[op_name] = self.get_tuned_config(op_name)
def get_vendor_heuristics_config(self):
    """Return the backend's heuristics config for the detected vendor."""
    vendor = self.device.vendor_name
    return backend.get_heuristic_config(vendor)
def get_default_heuristics_config(self):
    """Return the backend's heuristics config for ``"nvidia"``, used as the default."""
    default_vendor = "nvidia"
    return backend.get_heuristic_config(default_vendor)
def get_default_tune_config(self):
    """Return the backend's tune config for ``"nvidia"``, used as the default."""
    default_vendor = "nvidia"
    return backend.get_tune_config(default_vendor)
def get_vendor_tune_config(self):
    """Return the backend's tune config for the detected vendor."""
    vendor = self.device.vendor_name
    return backend.get_tune_config(vendor)
def get_heuristics_config(self, op_name):
    """Return the heuristics entry for ``op_name`` from the first source that has it.

    Sources are consulted in the order arch-specific, vendor, default.
    A source that is ``None`` or empty is skipped, so a vendor without a
    heuristics config (which ``__init__`` only warns about) falls through
    to the default table instead of raising ``TypeError``.

    Returns:
        The matching entry, or ``None`` after emitting a warning when no
        source defines ``op_name``.
    """
    sources = (
        self.arch_heuristics_config,
        self.vendor_heuristics_config,
        self.default_heuristics_config,
    )
    for source in sources:
        if source and op_name in source:
            return source[op_name]
    warnings.warn(f"No heuristics config found for {op_name}")
    return None
383 def _resolve_iteration_values(self, gen_config, config_var_key):
384 if isinstance(config_var_key, (list, tuple)):
385 return config_var_key
386 if isinstance(config_var_key, int):
387 return [config_var_key]
388 return gen_config[config_var_key]
390 def _gen_impl(
391 self,
392 gen_config,
393 iteration_plan,
394 std_config,
395 ):
396 all_configs = []
397 final_step = len(iteration_plan)
398 stack = [{"cur_config": std_config, "current_step": 0}]
400 while stack:
401 cur_state = stack[-1]
402 stack.pop()
403 cur_config = cur_state.get("cur_config")
404 current_step = cur_state.get("current_step")
406 if current_step == final_step:
407 all_configs.append(
408 triton.Config(
409 cur_config["META"],
410 num_warps=cur_config["num_warps"],
411 num_stages=cur_config["num_stages"],
412 num_ctas=cur_config["num_ctas"],
413 )
414 )
415 else:
416 cur_entry = iteration_plan[current_step]
417 cur_key = cur_entry["key"]
418 key_config = self._resolve_iteration_values(
419 gen_config, cur_entry["source"]
420 )
421 for single_value in key_config:
422 new_config = copy.deepcopy(cur_config)
423 if cur_entry["kind"] == "meta_field":
424 new_config["META"][cur_key] = single_value
425 elif cur_entry["kind"] == "meta_block":
426 new_config["META"] = copy.deepcopy(single_value)
427 else:
428 new_config[cur_key] = single_value
429 stack.append(
430 {
431 "cur_config": new_config,
432 "current_step": current_step + 1,
433 }
434 )
435 return all_configs
def to_gen_config(self, gen_config):
    """Expand one YAML "gen" entry into a list of ``triton.Config`` objects.

    ``gen_config["param_map"]`` drives the expansion. When its ``"META"``
    value is a dict, each item becomes its own ``meta_field`` step;
    otherwise the value is a single ``meta_block`` step that supplies whole
    META dicts. Every other ``param_map`` item becomes a ``config_field``
    step. The expansion starts from an empty META plus
    ``self.triton_config_default`` and is carried out by ``_gen_impl``.
    """
    param_map = gen_config["param_map"]
    meta_source = param_map["META"]

    if isinstance(meta_source, dict):
        plan = [
            {"key": name, "source": source, "kind": "meta_field"}
            for name, source in meta_source.items()
        ]
    else:
        plan = [{"key": "META", "source": meta_source, "kind": "meta_block"}]

    plan.extend(
        {"key": name, "source": source, "kind": "config_field"}
        for name, source in param_map.items()
        if name != "META"
    )

    seed = {"META": {}}
    seed.update(self.triton_config_default)
    return self._gen_impl(gen_config, plan, seed)
def get_expand_config(self, op_name, yaml_path=None):
    """Load the expand ranges and strategy for ``op_name``.

    Args:
        op_name: Operator to look up in ``self.expand_config_registry``.
        yaml_path: Fallback YAML path, used only when the registry entry
            has no (truthy) ``expand_yaml_path`` of its own.

    Returns:
        ``{"ranges": ..., "strategy": ...}`` on success. ``ranges`` maps
        the upper-cased name of every META source to its value list and
        adds ``"s"`` / ``"w"`` for the ``num_stages`` / ``num_warps``
        sources. ``strategy`` is the registry default unless the YAML
        provides a ``strategy`` dict, in which case it is rebuilt in key
        order with the default filling any gaps.

        ``-1`` when the op is not registered, the backend does not return
        a list, or anything raises while the YAML content is processed.
    """
    op_spec = self.expand_config_registry.get(op_name)
    if op_spec is None:
        return -1

    key = op_spec.get("key", [])
    default_strategy = op_spec.get("default_strategy")
    expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
    yaml_op_name = op_spec.get("yaml_op_name", op_name)

    try:
        entries = backend.get_expand_config(
            op_name=yaml_op_name,
            file_path=expand_yaml_path,
        )
        if not isinstance(entries, list):
            return -1

        # The last dict carrying each marker wins.
        gen_config = None
        strategy_config = None
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            if "param_map" in entry:
                gen_config = entry
            if "strategy" in entry:
                strategy_config = entry.get("strategy")

        # A missing "param_map" entry leaves gen_config as None; the
        # resulting AttributeError is turned into -1 by the handler below.
        param_map = gen_config.get("param_map")
        meta_map = param_map.get("META")

        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(name, default_strategy[position])
                for position, name in enumerate(key)
            ]
        else:
            strategy = default_strategy

        ranges = {
            mapped.upper(): gen_config[mapped] for mapped in meta_map.values()
        }
        ranges["s"] = gen_config[param_map.get("num_stages")]
        ranges["w"] = gen_config[param_map.get("num_warps")]

        return {"ranges": ranges, "strategy": strategy}
    except Exception:
        return -1
def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
    """Return the expanded ``triton.Config`` list for ``op_name``.

    The ranges come from ``get_expand_config``; when that reports failure
    with its ``-1`` sentinel the result is ``[]``. Otherwise the ranges and
    ``pre_hook`` are handed to ``_build_configs_by_op``.
    """
    expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
    if expand_config != -1:
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )
    return []
def get_tuned_config(self, op_name):
    """Return the list of ``triton.Config`` objects configured for ``op_name``.

    A cached list in ``self.loaded_triton_config`` is returned directly.
    Otherwise the op's YAML entries are fetched with ``_get_op_configs``:
    an entry containing ``self.gen_key`` is expanded by ``to_gen_config``,
    and any other entry becomes one config whose launch parameters are the
    defaults overridden by whatever the entry specifies. The result is not
    stored in the cache here. Returns ``[]`` when the op has no entries.
    """
    cache = self.loaded_triton_config
    if op_name in cache:
        return cache[op_name]

    op_entries = self._get_op_configs(op_name)
    if not op_entries:
        return []

    built = []
    for entry in op_entries:
        if self.gen_key in entry:
            built.extend(self.to_gen_config(entry))
        else:
            # Work on a copy so the shared defaults are never modified.
            launch = copy.deepcopy(self.triton_config_default)
            launch.update(
                {name: entry[name] for name in launch if name in entry}
            )
            built.append(self._create_triton_config(entry, launch))
    return built