Coverage for src/flag_gems/utils/libentry.py: 81%
379 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1from __future__ import annotations
3import hashlib
4import inspect
5import logging
6import math
7import multiprocessing
8import os
9import time
10from abc import abstractmethod
11from collections import OrderedDict
12from functools import cached_property
13from itertools import starmap
14from pathlib import Path
15from typing import (
16 Any,
17 Callable,
18 Dict,
19 Final,
20 Iterator,
21 List,
22 Optional,
23 Tuple,
24 Type,
25 Union,
26 overload,
27)
29import triton
31from flag_gems import runtime
32from flag_gems.runtime import torch_device_fn
33from flag_gems.runtime.backend import vendor_module
34from flag_gems.utils.code_cache import config_cache_dir
35from flag_gems.utils.models import PersistantModel, SQLPersistantModel
logger = logging.getLogger(__name__)

DEVICE_COUNT = runtime.device.device_count

# Parse the Triton version once at import time. Use `int` rather than `eval`:
# `int` is safe (no arbitrary code execution) and gives a clear ValueError on
# malformed components, whereas `eval` would execute whatever the string holds.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])
if major_version == 2:
    # Triton 2.x's `Config` has no `all_kwargs` method; backport it so the
    # rest of this module can treat 2.x and 3.x configs uniformly.
    def all_kwargs(self):
        """Return the config's kwargs merged with any launch options set on it."""
        launch_option_names = (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        )
        merged = dict(self.kwargs)
        for name in launch_option_names:
            # Only include options that this Triton build actually defines.
            if hasattr(self, name):
                merged[name] = getattr(self, name)
        return merged

    triton.Config.all_kwargs = all_kwargs

# Optional external database URL for the tuning cache; when unset, a local
# SQLite file under the config cache directory is used instead.
FLAGGEMS_DB_URL = os.getenv("FLAGGEMS_DB_URL", None)
class Cache(object):
    """Base class for persistent caches backed by a `PersistantModel`.

    Holds the table name and the persistence model shared by the concrete
    cache types (`ConfigCache`, `BenchmarkCache`).
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        # BUGFIX: the return annotation was `-> Cache`; `__init__` always
        # returns None.
        super().__init__(*args, **kwargs)
        # Name of the backing table; fixed for the lifetime of the cache.
        self.table_name: Final[str] = table_name
        # Persistence backend that actually stores/loads entries.
        self.model: Final[PersistantModel] = model
class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    Provides a dict-like interface (`in`, `[]`) over the persistence model,
    keyed by tuples of ints/floats/strings and valued by `triton.Config`.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        # BUGFIX: the return annotation was `-> ConfigCache`; `__init__`
        # always returns None.
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True if a best config is recorded for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the stored config for `key`, raising KeyError if absent."""
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Fetch the config for `key` from the backing table, or None."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Persist `config` as the best known configuration for `key`."""
        return self.model.put_config(self.table_name, key, config)
class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    Each instance is bound to one key; configs map to timing triples
    (as produced by the tuner's benchmark function).
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        # BUGFIX: the return annotation was `-> BenchmarkCache`; `__init__`
        # always returns None. The class description also moved from this
        # method's docstring to the class docstring where it belongs.
        super().__init__(table_name, model, *args, **kwargs)
        # The tuning key this benchmark cache is scoped to.
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True if a benchmark result is recorded for `config`."""
        # BUGFIX: previously called `self.model.get_benchmark(self.key, config)`,
        # omitting `self.table_name` which `get`/`set` always pass. Route
        # through `self.get` so all lookups use the same argument order.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float]:
        """Return the stored timings for `config`, raising KeyError if absent."""
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(self, config: triton.Config, benchmark: Tuple[float]) -> None:
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Fetch the timings recorded for `config` under this cache's key."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Persist `benchmark` timings for `config` under this cache's key."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)
class LibCache(object):
    """Process-wide singleton front-end over the persistent tuning database.

    Hands out `ConfigCache` objects (keyed by table name) and `BenchmarkCache`
    objects (keyed by (table name, tuning key)), memoizing them in pools so
    each table/key pair maps to a single cache object.
    """

    # Singleton storage; `__new__` always returns this once created.
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        # NOTE(review): `__init__` runs on *every* `LibCache(...)` call even
        # though `__new__` returns the existing instance, so re-instantiating
        # would reset the pools below. In this module it is only constructed
        # once (see `libcache` at module level) — confirm before constructing
        # elsewhere.
        self.global_cache: Dict = {}
        # NOTE(review): "volumn" looks like a typo for "volume"; kept as-is
        # since external code may reference the attribute name.
        self.volumn: Dict = {}
        if db_url is None:
            # Derive a per-device SQLite file name. Prefer the runtime's
            # device name; fall back to vendor metadata when the device
            # backend does not expose `get_device_name`.
            try:
                device_name: str = torch_device_fn.get_device_name().replace(" ", "_")
            except AttributeError:
                device_name: str = vendor_module.vendor_info.device_name
            # NVIDIA caches are named per device model; other vendors are
            # named per vendor only.
            cache_file_name: str = (
                f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db"
                if vendor_module.vendor_info.vendor_name == "nvidia"
                else f"TunedConfig_{vendor_module.vendor_info.vendor_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        # Memoization pools: one ConfigCache per table, one BenchmarkCache
        # per (table, key) pair.
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(self, key: Tuple[Union[int, float, str]]) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[Union[int, float, str], ...]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Dispatch on key type: str -> ConfigCache, tuple -> BenchmarkCache."""
        if isinstance(key, str):
            return self.get_config(key)
        elif isinstance(key, tuple):
            # The tuple is (table_name, tuning_key); unpack into get_benchmark.
            return self.get_benchmark(*key)
        else:
            assert False, f"the type of key '{key.__class__.__name__}' is unacceptable"

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return (creating and pooling if needed) the BenchmarkCache for (table, key)."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return (creating and pooling if needed) the ConfigCache for `table`."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret
215libcache = LibCache(FLAGGEMS_DB_URL)
class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.

    Benchmark results and best configs are persisted through the module-level
    `libcache`, so tuning survives across processes.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Maps strategy names (e.g. "default", "log", "align32") to key-transform callables.
    _strategy_table: Dict[str, Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
    ):
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        # Older Triton (<=3.1) requires explicit warmup/rep values.
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            # Triton 2.x Autotuner has a shorter constructor and does not set
            # `base_fn` itself, so unwrap to the innermost plain function here.
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        else:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        # NOTE(review): `do_bench` is accepted but not forwarded or stored
        # anywhere visible in this class — confirm whether it should be used.
        self.__name__ = self.base_fn.__name__
        self.keys = key
        # Normalize `strategy` to a list of callables, one per tuning key.
        if isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        strategy: List[Callable[[Any], Any]] = [
            LibTuner.get_strategy(s) if isinstance(s, str) else s for s in strategy
        ]
        self.strategy: List[Callable[[Any], Any]] = strategy
        # Table names tie persisted results to this kernel's source + configs.
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # `libcache[str]` returns a ConfigCache (annotation fixed; it was
        # previously annotated as BenchmarkCache).
        self.cache: ConfigCache = libcache[self.config_table_name]

    @cached_property
    def cache_key(self) -> str:
        """Cache key of the innermost JITFunction (identifies kernel source)."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 over kernel cache key + config-space hash; names the config table."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """MD5 over the stringified config search space."""
        return hashlib.md5(
            ",".join(map(lambda config: str(config), self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning-cache key from the runtime argument values.

        Applies the per-key strategy callables to the values of `self.keys`,
        then appends the dtype of every tensor-like argument.
        """
        # NOTE(review): after __init__, `self.strategy` is always a list, so
        # the `is None` branch appears unreachable via the normal constructor
        # path — confirm before relying on it.
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                starmap(
                    lambda idx0, idx1: self.strategy[idx0](args[idx1]),
                    enumerate(self.keys),
                )
            )
        # Include dtypes so the same sizes with different dtypes tune separately.
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    # NOTE(review): declared @staticmethod yet takes `self`; subclasses
    # (including those produced by `register_policy`) override it as an
    # instance method, so this stub only serves as an abstract placeholder.
    @staticmethod
    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Look up a registered `LibTuner` subclass by name (KeyError if absent)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Look up a registered strategy callable by name (KeyError if absent)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Type[LibTuner]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            # Synthesize and register an anonymous subclass whose `policy`
            # delegates to the supplied implementation.
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    return policy_impl(fn, configs, args, kwargs)

            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Register a key-transform strategy callable under `name`."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Select (tuning if necessary) the best config and launch the kernel.

        The best config per key is persisted in `self.cache`; individual
        (key, config) benchmark timings are persisted in a BenchmarkCache so
        interrupted tuning runs can resume without re-benchmarking.
        """
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Reuse a persisted timing when available; otherwise
                    # benchmark and persist for future runs.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            # Configs loaded from the persistent cache lose their pre_hook;
            # match them back to the in-memory originals by kwargs.
            if config.pre_hook is None:
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            # Single-config kernels need no tuning.
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret
@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Identity strategy: use the raw key value unchanged."""
    return key
@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> float:
    """Round `key` up to the nearest power of two.

    Reduces tuning-key cardinality so similar sizes share cached configs.
    Raises ValueError for key <= 0 (math.log2 domain).
    """
    return 2 ** math.ceil(math.log2(key))
@LibTuner.register_strategy("align32")
def align32_strategy(key: Union[int, float]) -> int:
    """Round `key` up to the next multiple of 32."""
    return math.ceil(key / 32) * 32
@LibTuner.register_policy("default")
def default_policy(
    # Annotation fixed below: bench_fn returns List[float], so the timings
    # dict maps configs to lists of floats (compared lexicographically by min).
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Default policy for offline autotuning.

    Args:
        bench_fn: The function to benchmark.
        configs: The collection of the configuration search space.
        args: Kernel launch arguments.
        kwargs: Kernel launch arguments.
    Returns:
        A tuple containing the best configuration and a dictionary of timings for each configuration.

    This is one way to implement a default policy for offline autotuning. It's equal to the following
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(
            self,
            *args,
            **kwargs,
        ):
            super().__init__(
                *args,
                **kwargs,
            )

        @staticmethod
        def policy(
            bench_fn: Callable[[triton.Config], List[float]],
            configs: Iterator[triton.Config],
            args: Tuple[Any],
            kwargs: Dict[str, Any],
        ) -> Tuple[triton.Config, Dict[str, float]]:
            timings: Dict[triton.Config, List[float]] = {
                config: bench_fn(config) for config in configs
            }
            best_config: triton.Config = min(timings, key=timings.get)
            return best_config, timings
    ```
    In this way policies could be extended by registering a definition function quickly,
    or by creating a new subclass of `LibTuner` and overriding the `policy` method to have
    more control over the autotuning process.
    """
    # Exhaustive search: benchmark every config, pick the lexicographic
    # minimum of the timing lists.
    timings: Dict[triton.Config, List[float]] = {
        config: bench_fn(config) for config in configs
    }
    best_config: triton.Config = min(timings, key=timings.get)
    return best_config, timings
def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
):
    """Decorator for triton library autotuner.

    `strategy` is a function that takes a key and returns a value.
    It accepts a string, which is the name of a registered strategy, or a callable function.
    In this form it will be applied to each key in the `key` list.
    If it's a tuple or list, it should have the same length as `key`,
    and each element should be a string or a callable function that takes a key and returns a value.
    `policy` accepts a string, which is the name of a registered `LibTuner` subclass, or a `LibTuner` subclass itself.

    Returns:
        A decorator that wraps a Triton JIT function in the resolved
        `LibTuner` subclass.
    Raises:
        AssertionError: if `policy` does not resolve to a `LibTuner` subclass.
        KeyError: if `policy` is a name not present in the dispatch table.
    """

    # Resolve a policy name to its registered subclass before validation.
    if isinstance(policy, str):
        policy = LibTuner.get(policy)
    assert issubclass(
        policy, LibTuner
    ), f"the class of {policy.__name__} is {policy.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    def decorator(fn):
        # Instantiate the tuner subclass around the jitted function; the
        # tuner itself is the callable the caller ends up holding.
        return policy(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            pre_hook=pre_hook,
            post_hook=post_hook,
            prune_configs_by=prune_configs_by,
            warmup=warmup,
            rep=rep,
            use_cuda_graph=use_cuda_graph,
            do_bench=do_bench,
            strategy=strategy,
        )

    return decorator
class LibEntry(triton.KernelInterface):
    """Per-device kernel-launch cache wrapped around a (possibly tuned) kernel.

    On the first call with a given argument signature it runs the full Triton
    dispatch chain (Autotuner/Heuristics/JITFunction) and caches the compiled
    kernel plus the constexpr values those wrappers chose; subsequent calls
    with the same signature launch the cached kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer-alignment divisor used when specializing tensor arguments.
        self.divisibility = 16
        # One kernel cache per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))

        # Unwrap Autotuner/Heuristics layers down to the raw JITFunction.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        # Parameter indices partitioned by specialization behavior.
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        # NOTE(review): a `multiprocessing.Lock` is used here although the
        # comment in `run` talks about thread safety — confirm whether a
        # `threading.Lock` was intended (both serialize threads in-process).
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key from the partitioned arguments."""

        def spec_arg(arg):
            # Tensor-like args key on (dtype, pointer-alignment); others on
            # (type, value).
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            # Do-not-specialize args key on dtype/type/integer width only,
            # never on the value itself.
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            # NOTE(review): ints in [2**31, 2**63) fall through to "i64";
            # ints >= 2**64 also return "i64" — confirm that matches Triton's
            # own integer classification.
            if 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling/tuning on first use of a signature.

        Returns:
            (compiled kernel, constexprs dict) for both the cold and warm paths.
        """
        grid = kwargs["grid"]

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        for i, arg in enumerate(args):
            hashable_arg = arg
            if (
                hasattr(arg, "__class__")
                and arg.__class__.__name__ == "TensorDescriptor"
            ):
                # Create a hashable representation of TensorDescriptor
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    arg.padding if hasattr(arg, "padding") else None,
                    # Add other relevant attributes
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                # Triton 3.3-3.6 expects constexprs positionally at launch,
                # so they must be tracked in k_args as well.
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        # Fill in the remaining parameters from kwargs or their defaults.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                if entry_key in cache:
                    break
                # Cold path: run the full wrapper chain once to compile/tune.
                kernel = self.fn.run(*args, **kwargs)
                fn = self.fn
                # collect constexpr arguments for grid computation
                constexprs = {}
                tune_constexprs = {}
                heur_constexprs = {}
                while not isinstance(fn, triton.runtime.JITFunction):
                    if isinstance(fn, triton.runtime.Autotuner):
                        config = fn.best_config
                        constexprs["num_warps"] = config.num_warps
                        constexprs["num_stages"] = config.num_stages
                        constexprs["num_ctas"] = config.num_ctas
                        constexprs = {**constexprs, **config.kwargs}
                        tune_constexprs = {**tune_constexprs, **config.kwargs}
                    elif isinstance(fn, triton.runtime.Heuristics):
                        for v, heur in fn.values.items():
                            heur_constexprs[v] = heur(
                                {
                                    **dict(zip(fn.arg_names, args)),
                                    **kwargs,
                                    **constexprs,
                                }
                            )
                            constexprs[v] = heur_constexprs[v]
                    else:
                        raise RuntimeError("Invalid Runtime Function")
                    fn = fn.fn
                # Defaulted constexpr params not chosen by tuner/heuristics
                # still need values for grid computation.
                for p in self.jit_function.params:
                    if (
                        p.is_constexpr
                        and p.name not in constexprs
                        and (p.default is not inspect._empty)
                    ):
                        constexprs[p.name] = p.default
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                )
                # The compiling caller already launched via self.fn.run above.
                return kernel, constexprs

        # Warm path: reuse the cached kernel and constexpr values.
        kernel, constexprs, tune_constexprs, heur_constexprs = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn,ie:
            # 1. args,
            # 2. kwargs,
            # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad to at least 3 dimensions; only the first three are used below.
        grid = grid + (1, 1)

        if major_version == 3 and 3 <= minor_version <= 6:
            # These Triton versions take every parameter positionally, so
            # rebuild the full argument list in signature order.
            all_args = []
            missing_keys = []
            for key in list(self.signature.parameters.keys()):
                if key in k_args:
                    all_args.append(k_args[key])
                elif key in tune_constexprs:
                    all_args.append(tune_constexprs[key])
                elif key in heur_constexprs:
                    all_args.append(heur_constexprs[key])
                elif key in constexprs:
                    all_args.append(constexprs[key])
                else:
                    missing_keys.append(key)
            if len(missing_keys):
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs
def libentry():
    """Decorator factory for triton library entries.

    Returns a decorator that wraps a kernel in a `LibEntry` launch cache.
    """

    def wrap(fn):
        return LibEntry(fn)

    return wrap