Coverage for src/flag_gems/runtime/configs_loader.py: 65%
237 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
import copy
import inspect
import itertools
import os
import warnings

import triton

from . import backend, common
from .backend.device import DeviceDetector
12class TunedConfigLoader(object):
13 _instance = None
15 def __new__(cls, *args, **kargs):
16 if cls._instance is None:
17 cls._instance = super(TunedConfigLoader, cls).__new__(cls)
18 return cls._instance
20 def __init__(self):
21 if not hasattr(self, "initialized"):
22 self.initialized = True
23 self.device = DeviceDetector()
24 # primitive_yaml_config is simply the dictionary returned by yaml
25 # and is reserved from being an attr for vendor customizability
26 self.arch_specialized_yaml_config = None
27 self.arch_heuristics_config = None
28 self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
29 self.default_primitive_yaml_config = self.get_default_tune_config()
30 self.vendor_heuristics_config = self.get_vendor_heuristics_config()
31 self.default_heuristics_config = self.get_default_heuristics_config()
32 self.update_config_from_arch()
34 if self.vendor_heuristics_config is None:
35 vendorname = self.device.vendor_name
36 warnings.warn(
37 f"The {vendorname} configuration of heuristics_config is None"
38 )
39 # gen_key is an identifier that indicates whether the current config needs to be generated automatically
40 self.gen_key = "gen"
41 # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config
42 self.loaded_triton_config = {}
43 self.triton_config_default = {
44 "num_stages": 2,
45 "num_warps": 4,
46 "num_ctas": 1,
47 }
48 if self.device.vendor_name == "hygon":
49 self.triton_config_default["num_ldmatrixes"] = 0
50 self.expand_config_registry = self._build_expand_registry()
51 self.load_all()
53 def update_config_from_arch(self):
54 try:
55 archEvent = backend.BackendArchEvent()
56 if archEvent.has_arch:
57 self.arch_specialized_yaml_config = archEvent.autotune_configs
58 self.arch_heuristics_config = archEvent.heuristics_configs
59 except Exception as err:
60 print(f"[INFO] : {err}")
62 def _get_op_configs(self, op_name):
63 """Get config for op_name from available config sources."""
64 for config in (
65 self.arch_specialized_yaml_config,
66 self.vendor_primitive_yaml_config,
67 self.default_primitive_yaml_config,
68 ):
69 if config and op_name in config:
70 return config[op_name]
71 return []
73 def _create_triton_config(self, single_config, current_config):
74 """Create a triton.Config with appropriate parameters."""
75 kwargs = {
76 "num_warps": current_config["num_warps"],
77 "num_stages": current_config["num_stages"],
78 "num_ctas": current_config["num_ctas"],
79 }
80 if (
81 self.device.vendor_name == "hygon"
82 and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
83 ):
84 kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
85 return triton.Config(single_config["META"], **kwargs)
87 def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
88 if op_name == "bmm":
89 return [
90 triton.Config(
91 {
92 "TILE_M": block_m,
93 "TILE_N": block_n,
94 "TILE_K": block_k,
95 "GROUP_M": 1 if block_m == 32 else 2,
96 },
97 num_stages=s,
98 num_warps=w,
99 pre_hook=pre_hook,
100 )
101 for block_m in ranges["BLOCK_M"]
102 for block_n in ranges["BLOCK_N"]
103 for block_k in ranges["BLOCK_K"]
104 for s in ranges["s"]
105 for w in ranges["w"]
106 ]
108 if op_name == "addmm":
109 return [
110 triton.Config(
111 {
112 "BLOCK_SIZE_M": block_m,
113 "BLOCK_SIZE_N": block_n,
114 "BLOCK_SIZE_K": block_k,
115 },
116 num_stages=s,
117 num_warps=w,
118 pre_hook=pre_hook,
119 )
120 for block_m in ranges["BLOCK_M"]
121 for block_n in ranges["BLOCK_N"]
122 for block_k in ranges["BLOCK_K"]
123 for s in ranges["s"]
124 for w in ranges["w"]
125 ]
127 if op_name == "baddbmm":
128 return [
129 triton.Config(
130 {
131 "TILE_M": block_m,
132 "TILE_N": block_n,
133 "TILE_K": block_k,
134 "GROUP_M": 1 if block_m <= 32 else 2,
135 },
136 num_stages=s,
137 num_warps=w,
138 pre_hook=pre_hook,
139 )
140 for block_m in ranges["BLOCK_M"]
141 for block_n in ranges["BLOCK_N"]
142 for block_k in ranges["BLOCK_K"]
143 for s in ranges["s"]
144 for w in ranges["w"]
145 ]
147 if op_name == "mv":
148 return [
149 triton.Config(
150 {
151 "BLOCK_N": block_n,
152 "BLOCK_M": block_m,
153 },
154 num_stages=s,
155 num_warps=w,
156 pre_hook=pre_hook,
157 )
158 for block_n in ranges["BLOCK_N"]
159 for block_m in ranges["BLOCK_M"]
160 for s in ranges["s"]
161 for w in ranges["w"]
162 ]
164 if op_name == "mm_general_tma":
165 group_m_values = ranges.get("GROUP_M", [8])
166 return [
167 triton.Config(
168 {
169 "BLOCK_M": block_m,
170 "BLOCK_N": block_n,
171 "BLOCK_K": block_k,
172 "GROUP_M": group_m,
173 },
174 num_stages=s,
175 num_warps=w,
176 pre_hook=pre_hook,
177 )
178 for block_m in ranges["BLOCK_M"]
179 for block_n in ranges["BLOCK_N"]
180 for block_k in ranges["BLOCK_K"]
181 for group_m in group_m_values
182 for s in ranges["s"]
183 for w in ranges["w"]
184 ]
186 if op_name in ("mm", "mm_sqmma"):
187 return [
188 triton.Config(
189 {
190 "BLOCK_M": block_m,
191 "BLOCK_N": block_n,
192 "BLOCK_K": block_k,
193 },
194 num_stages=s,
195 num_warps=w,
196 pre_hook=pre_hook,
197 )
198 for block_m in ranges["BLOCK_M"]
199 for block_n in ranges["BLOCK_N"]
200 for block_k in ranges["BLOCK_K"]
201 for s in ranges["s"]
202 for w in ranges["w"]
203 ]
205 if op_name in ("bmm_sqmma", "addmm_sqmma"):
206 return [
207 triton.Config(
208 {
209 "BLOCK_SIZE_M": block_m,
210 "BLOCK_SIZE_N": block_n,
211 "BLOCK_SIZE_K": block_k,
212 },
213 num_stages=s,
214 num_warps=w,
215 pre_hook=pre_hook,
216 )
217 for block_m in ranges["BLOCK_M"]
218 for block_n in ranges["BLOCK_N"]
219 for block_k in ranges["BLOCK_K"]
220 for s in ranges["s"]
221 for w in ranges["w"]
222 ]
224 if op_name == "gemv":
225 return [
226 triton.Config(
227 {"BLOCK_M": block_m, "BLOCK_K": block_k},
228 num_stages=s,
229 num_warps=w,
230 pre_hook=pre_hook,
231 )
232 for block_m in ranges["BLOCK_M"]
233 for block_k in ranges["BLOCK_K"]
234 for s in ranges["s"]
235 for w in ranges["w"]
236 ]
238 if op_name == "sparse_attention":
239 return [
240 triton.Config(
241 {"BLOCK": block},
242 num_stages=s,
243 num_warps=w,
244 pre_hook=pre_hook,
245 )
246 for block in ranges["BLOCK"]
247 for s in ranges["s"]
248 for w in ranges["w"]
249 ]
251 if op_name == "w8a8_block_fp8_general":
252 return [
253 triton.Config(
254 {
255 "BLOCK_M": block_m,
256 "BLOCK_N": block_n,
257 "BLOCK_K": block_k,
258 "GROUP_M": group_m,
259 },
260 num_stages=s,
261 num_warps=w,
262 pre_hook=pre_hook,
263 )
264 for block_m in ranges["BLOCK_M"]
265 for block_n in ranges["BLOCK_N"]
266 for block_k in ranges["BLOCK_K"]
267 for group_m in ranges["GROUP_M"]
268 for s in ranges["s"]
269 for w in ranges["w"]
270 ]
272 if op_name == "w8a8_block_fp8_general_tma":
273 group_m_values = ranges.get("GROUP_M", [None])
274 return [
275 triton.Config(
276 dict(
277 {
278 "BLOCK_M": block_m,
279 "BLOCK_N": block_n,
280 "BLOCK_K": block_k,
281 },
282 **({} if group_m is None else {"GROUP_M": group_m}),
283 ),
284 num_stages=s,
285 num_warps=w,
286 pre_hook=pre_hook,
287 )
288 for block_m in ranges["BLOCK_M"]
289 for block_n in ranges["BLOCK_N"]
290 for block_k in ranges["BLOCK_K"]
291 for group_m in group_m_values
292 for s in ranges["s"]
293 for w in ranges["w"]
294 ]
296 if op_name == "w8a8_block_fp8_general_splitk":
297 return [
298 triton.Config(
299 {
300 "BLOCK_M": block_m,
301 "BLOCK_N": block_n,
302 "BLOCK_K": block_k,
303 "SPLIT_K": split_k,
304 },
305 num_stages=s,
306 num_warps=w,
307 pre_hook=pre_hook,
308 )
309 for block_m in ranges["BLOCK_M"]
310 for block_n in ranges["BLOCK_N"]
311 for block_k in ranges["BLOCK_K"]
312 for split_k in ranges["SPLIT_K"]
313 for s in ranges["s"]
314 for w in ranges["w"]
315 ]
317 if op_name == "mm_splitk":
318 return [
319 triton.Config(
320 {
321 "BLOCK_M": block_m,
322 "BLOCK_N": block_n,
323 "BLOCK_K": block_k,
324 "SPLIT_K": split_k,
325 },
326 num_stages=s,
327 num_warps=w,
328 pre_hook=pre_hook,
329 )
330 for block_m in ranges["BLOCK_M"]
331 for block_n in ranges["BLOCK_N"]
332 for block_k in ranges["BLOCK_K"]
333 for split_k in ranges["SPLIT_K"]
334 for s in ranges["s"]
335 for w in ranges["w"]
336 ]
338 return []
340 def _build_single_expand_spec(
341 self,
342 op_name,
343 expand_yaml_path=None,
344 yaml_op_name=None,
345 ):
346 return {
347 "yaml_op_name": yaml_op_name or op_name,
348 "key": common.OP_KEY_ORDERS[op_name],
349 "default_strategy": common.DEFAULT_STRATEGIES[op_name],
350 "expand_yaml_path": expand_yaml_path,
351 }
353 def _iter_expand_config_candidates(self, op_name):
354 vendor_name = self.device.vendor_name
355 contexts = []
356 try:
357 arch_event = backend.BackendArchEvent()
358 current_arch_path = getattr(arch_event, "current_arch_path", None)
359 arch_name = getattr(arch_event, "arch", None)
360 if arch_event.has_arch and current_arch_path:
361 contexts.append((current_arch_path, arch_name))
362 except Exception:
363 pass
365 backend_dir = os.path.join(os.path.dirname(__file__), "backend")
366 contexts.append((os.path.join(backend_dir, f"_{vendor_name}"), vendor_name))
368 seen = set()
369 for base_dir, backend_name in contexts:
370 filenames = []
371 if op_name:
372 filenames.extend(
373 (
374 f"{op_name}_{backend_name}_expand.yaml",
375 f"{op_name}_{vendor_name}_expand.yaml",
376 f"{op_name}_expand.yaml",
377 )
378 )
379 filenames.extend(
380 (
381 f"general_ops_{backend_name}_expand.yaml",
382 f"general_ops_{vendor_name}_expand.yaml",
383 "general_ops_expand.yaml",
384 )
385 )
387 for filename in filenames:
388 path = os.path.normpath(os.path.join(base_dir, filename))
389 if path in seen:
390 continue
391 seen.add(path)
392 yield path
394 def _get_expand_config_path(self, op_name):
395 for path in self._iter_expand_config_candidates(op_name):
396 if os.path.exists(path):
397 return path
398 return None
400 def _build_expand_registry(self):
401 return {
402 "addmm": self._build_single_expand_spec(
403 "addmm", expand_yaml_path=self._get_expand_config_path("addmm")
404 ),
405 "addmm_sqmma": self._build_single_expand_spec("addmm_sqmma"),
406 "baddbmm": self._build_single_expand_spec(
407 "baddbmm", expand_yaml_path=self._get_expand_config_path("baddbmm")
408 ),
409 "bmm": self._build_single_expand_spec(
410 "bmm", expand_yaml_path=self._get_expand_config_path("bmm")
411 ),
412 "bmm_sqmma": self._build_single_expand_spec("bmm_sqmma"),
413 "gemv": self._build_single_expand_spec("gemv"),
414 "mm": self._build_single_expand_spec(
415 "mm", expand_yaml_path=self._get_expand_config_path("mm")
416 ),
417 "mm_general_tma": self._build_single_expand_spec("mm_general_tma"),
418 "mv": self._build_single_expand_spec(
419 "mv", expand_yaml_path=self._get_expand_config_path("mv")
420 ),
421 "w8a8_block_fp8_general": self._build_single_expand_spec(
422 "w8a8_block_fp8_general"
423 ),
424 "w8a8_block_fp8_general_splitk": self._build_single_expand_spec(
425 "w8a8_block_fp8_general_splitk"
426 ),
427 "w8a8_block_fp8_general_tma": self._build_single_expand_spec(
428 "w8a8_block_fp8_general_tma"
429 ),
430 "mm_splitk": self._build_single_expand_spec("mm_splitk"),
431 "sparse_attention": self._build_single_expand_spec("sparse_attention"),
432 }
434 def load_all(self):
435 for key in self.vendor_primitive_yaml_config:
436 self.loaded_triton_config[key] = self.get_tuned_config(key)
438 def get_vendor_heuristics_config(self):
439 return backend.get_heuristic_config(self.device.vendor_name)
441 def get_default_heuristics_config(self):
442 return backend.get_heuristic_config("nvidia")
444 def get_default_tune_config(self):
445 return backend.get_tune_config("nvidia")
447 def get_vendor_tune_config(self):
448 return backend.get_tune_config(self.device.vendor_name)
450 def get_heuristics_config(self, op_name):
451 if self.arch_heuristics_config and op_name in self.arch_heuristics_config:
452 return self.arch_heuristics_config[op_name]
453 elif op_name in self.vendor_heuristics_config:
454 return self.vendor_heuristics_config[op_name]
455 elif op_name in self.default_heuristics_config:
456 return self.default_heuristics_config[op_name]
457 else:
458 warnings.warn(f"No heuristics config found for {op_name}")
459 return None
461 def _resolve_iteration_values(self, gen_config, config_var_key):
462 if isinstance(config_var_key, (list, tuple)):
463 return config_var_key
464 if isinstance(config_var_key, int):
465 return [config_var_key]
466 return gen_config[config_var_key]
468 def _gen_impl(
469 self,
470 gen_config,
471 iteration_plan,
472 std_config,
473 ):
474 all_configs = []
475 final_step = len(iteration_plan)
476 stack = [{"cur_config": std_config, "current_step": 0}]
478 while stack:
479 cur_state = stack[-1]
480 stack.pop()
481 cur_config = cur_state.get("cur_config")
482 current_step = cur_state.get("current_step")
484 if current_step == final_step:
485 all_configs.append(
486 triton.Config(
487 cur_config["META"],
488 num_warps=cur_config["num_warps"],
489 num_stages=cur_config["num_stages"],
490 num_ctas=cur_config["num_ctas"],
491 )
492 )
493 else:
494 cur_entry = iteration_plan[current_step]
495 cur_key = cur_entry["key"]
496 key_config = self._resolve_iteration_values(
497 gen_config, cur_entry["source"]
498 )
499 for single_value in key_config:
500 new_config = copy.deepcopy(cur_config)
501 if cur_entry["kind"] == "meta_field":
502 new_config["META"][cur_key] = single_value
503 elif cur_entry["kind"] == "meta_block":
504 new_config["META"] = copy.deepcopy(single_value)
505 else:
506 new_config[cur_key] = single_value
507 stack.append(
508 {
509 "cur_config": new_config,
510 "current_step": current_step + 1,
511 }
512 )
513 return all_configs
515 def to_gen_config(self, gen_config):
516 param_config = gen_config["param_map"]
517 meta_config = param_config["META"]
518 iteration_plan = []
520 if isinstance(meta_config, dict):
521 for meta_key, source in meta_config.items():
522 iteration_plan.append(
523 {"key": meta_key, "source": source, "kind": "meta_field"}
524 )
525 else:
526 iteration_plan.append(
527 {"key": "META", "source": meta_config, "kind": "meta_block"}
528 )
530 for key, source in param_config.items():
531 if key == "META":
532 continue
533 iteration_plan.append(
534 {"key": key, "source": source, "kind": "config_field"}
535 )
537 current_config = {"META": {}}
538 current_config.update(self.triton_config_default)
539 return self._gen_impl(
540 gen_config,
541 iteration_plan,
542 current_config,
543 )
545 def get_expand_config(self, op_name, yaml_path=None):
546 op_spec = self.expand_config_registry.get(op_name)
547 if op_spec is None:
548 return -1
550 key = op_spec.get("key", [])
551 default_strategy = op_spec.get("default_strategy")
552 expand_yaml_path = yaml_path or op_spec.get("expand_yaml_path")
553 yaml_op_name = op_spec.get("yaml_op_name", op_name)
554 if not expand_yaml_path:
555 return -1
557 try:
558 expand_configs = backend.get_expand_config(
559 op_name=yaml_op_name,
560 file_path=expand_yaml_path,
561 )
562 if not isinstance(expand_configs, list):
563 return -1
565 gen_config = None
566 strategy_config = None
567 for single_config in expand_configs:
568 if isinstance(single_config, dict) and "param_map" in single_config:
569 gen_config = single_config
571 if isinstance(single_config, dict) and "strategy" in single_config:
572 strategy_config = single_config.get("strategy")
574 param_map = gen_config.get("param_map")
575 meta_map = param_map.get("META")
577 strategy = default_strategy
578 if isinstance(strategy_config, dict):
579 strategy = [
580 strategy_config.get(k, default_strategy[idx])
581 for idx, k in enumerate(key)
582 ]
584 ranges = {}
586 for mapped_key in meta_map.values():
587 ranges[mapped_key.upper()] = gen_config[mapped_key]
588 ranges["s"] = gen_config[param_map.get("num_stages")]
589 ranges["w"] = gen_config[param_map.get("num_warps")]
591 return {
592 "ranges": ranges,
593 "strategy": strategy,
594 }
595 except Exception:
596 return -1
598 def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
599 expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
600 if expand_config == -1:
601 return []
602 ranges = expand_config["ranges"]
603 return self._build_configs_by_op(op_name, ranges, pre_hook=pre_hook)
605 def get_tuned_config(self, op_name):
606 if op_name in self.loaded_triton_config:
607 return self.loaded_triton_config[op_name]
609 current_op_configs = self._get_op_configs(op_name)
610 if not current_op_configs:
611 return []
613 configs = []
615 for single_config in current_op_configs:
616 if self.gen_key in single_config:
617 configs.extend(self.to_gen_config(single_config))
618 continue
620 current_config = copy.deepcopy(self.triton_config_default)
621 for default_param in current_config:
622 if default_param in single_config:
623 current_config[default_param] = single_config[default_param]
625 configs.append(self._create_triton_config(single_config, current_config))
626 return configs