Coverage for src/flag_gems/utils/libentry.py: 80%

379 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1from __future__ import annotations 

2 

3import hashlib 

4import inspect 

5import logging 

6import math 

7import multiprocessing 

8import os 

9import time 

10from abc import abstractmethod 

11from collections import OrderedDict 

12from functools import cached_property 

13from itertools import starmap 

14from pathlib import Path 

15from typing import ( 

16 Any, 

17 Callable, 

18 Dict, 

19 Final, 

20 Iterator, 

21 List, 

22 Optional, 

23 Tuple, 

24 Type, 

25 Union, 

26 overload, 

27) 

28 

29import triton 

30 

31from flag_gems import runtime 

32from flag_gems.runtime import torch_device_fn 

33from flag_gems.runtime.backend import vendor_module 

34from flag_gems.utils.code_cache import config_cache_dir 

35from flag_gems.utils.models import PersistantModel, SQLPersistantModel 

36 

logger = logging.getLogger(__name__)

DEVICE_COUNT = runtime.device.device_count

# Installed Triton version as integers, e.g. "3.1.0" -> (3, 1).
# `int` rather than `eval`: version components are data, not code, and
# `eval` rejects components with leading zeros such as "08" (SyntaxError).
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])

43 

44 

if major_version == 2:

    def all_kwargs(self):
        """Return the config's kernel kwargs merged with its launch options.

        Starts from a copy of ``self.kwargs`` and then adds each launch option
        listed below that this ``triton.Config`` instance actually has.
        Installed only when running on Triton 2.x (see the guard above); the
        rest of this module calls ``config.all_kwargs()`` unconditionally.
        """
        merged = dict(self.kwargs)
        launch_options = (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        )
        for option in launch_options:
            if hasattr(self, option):
                merged[option] = getattr(self, option)
        return merged

    triton.Config.all_kwargs = all_kwargs

67 

# Optional database URL for the tuned-config cache. When unset (None),
# `LibCache` derives a per-device SQLite file path instead.
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")

69 

70 

class Cache:
    """Base class for the persistent caches: binds a table name to a storage model.

    Subclasses route every read/write through ``self.model`` using
    ``self.table_name`` to select the table.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Store the table name and backing model.

        Args:
            table_name: Name of the table in ``model`` this cache uses.
            model: Storage backend that performs the lookups and writes.
            *args, **kwargs: Forwarded to the next class in the MRO.
        """
        super().__init__(*args, **kwargs)
        # Both are fixed for the lifetime of the cache object.
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model

78 

79 

class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    It behaves like a mapping from a tuning key (a tuple of ints/floats/strs)
    to a ``triton.Config``; all storage goes through ``self.model`` under
    ``self.table_name``.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True if a config is stored for ``key``."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the stored config for ``key``.

        Raises:
            KeyError: if nothing is stored for ``key``.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Return the stored config for ``key``, or None (as reported by the model)."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Persist ``config`` as the config for ``key``."""
        return self.model.put_config(self.table_name, key, config)

111 

112 

class BenchmarkCache(Cache):
    """`BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    One instance is scoped to a single tuning ``key``; it behaves like a
    mapping from ``triton.Config`` to the timing tuple recorded for it. All
    storage goes through ``self.model`` under ``self.table_name``.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        """Bind the cache to ``table_name`` in ``model`` for tuning key ``key``."""
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True if timings are stored for ``config`` under this key."""
        # Route through `get` so the lookup carries `table_name` like every
        # other access. Calling `model.get_benchmark(self.key, config)`
        # directly dropped the table name and shifted the arguments.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, ...]:
        """Return the stored timings for ``config``.

        Raises:
            KeyError: if no timings are stored for ``config``.
        """
        ret: Optional[Tuple[float, ...]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(self, config: triton.Config, benchmark: Tuple[float, ...]) -> None:
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, ...]]:
        """Return the stored timings for ``config``, or None (as reported by the model)."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, ...]) -> None:
        """Persist ``benchmark`` as the timings of ``config`` under this key."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)

147 

148 

class LibCache(object):
    """Registry of the persistent config and benchmark caches.

    ``__new__`` always returns the same instance (a singleton). All caches it
    hands out share one ``SQLPersistantModel`` bound to ``self.db_url``.

    NOTE(review): ``__init__`` still runs on every ``LibCache(...)`` call, so
    constructing it again resets the cache pools and rebinds the model to the
    new ``db_url``. The only construction in this module is ``libcache``
    at the bottom of this block.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Bind the registry to a database.

        Args:
            db_url: Database URL. When None, a SQLite file under
                ``config_cache_dir()`` is used. Its name encodes the device name
                (NVIDIA) or the vendor name (others) plus the Triton major/minor
                version.
        """
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        if db_url is None:
            try:
                device_name: str = torch_device_fn.get_device_name().replace(" ", "_")
            except AttributeError:
                # Fallback for backends whose device module lacks `get_device_name`.
                device_name: str = vendor_module.vendor_info.device_name
            cache_file_name: str = (
                f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db"
                if vendor_module.vendor_info.vendor_name == "nvidia"
                else f"TunedConfig_{vendor_module.vendor_info.vendor_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        # Per-table config caches, created lazily by `get_config`.
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        # Per-(table, tuning key) benchmark caches, created lazily by `get_benchmark`.
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[str, Tuple[Union[int, float, str], ...]]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Dispatch on the key type.

        ``libcache[table]`` returns the ``ConfigCache`` for ``table``.
        ``libcache[table, key]`` returns the ``BenchmarkCache`` for that table and
        tuning key.
        """
        if isinstance(key, str):
            return self.get_config(key)
        elif isinstance(key, tuple):
            return self.get_benchmark(*key)
        else:
            assert False, f"the type of key '{key.__class__.__name__}' is unacceptable"

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the (memoized) ``BenchmarkCache`` for ``table`` and tuning ``key``."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the (memoized) ``ConfigCache`` for ``table``."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret


# Module-wide cache registry used by `LibTuner`. A None FLAGGEMS_DB_URL selects
# the default per-device SQLite file.
libcache = LibCache(FLAGGEMS_DB_URL)

216 

217 

class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.

    The best config per tuning key is persisted in the `libcache` table
    `config_table_name`. Raw per-config benchmark results go to the
    `benchmark_table_name` table.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Key-bucketing strategies registered via `register_strategy`; shared across all instances.
    _strategy_table: Dict[str, Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
    ):
        """Build the tuner and bind it to its persistent config table.

        Args:
            strategy: How each entry of ``key`` is bucketed before cache lookup.
                It may be one of the following:
                - a registered strategy name (or ``None``, registered as an
                  alias of ``"default"``);
                - a callable;
                - a list/tuple with one name-or-callable per key.
            do_bench: Accepted for signature compatibility; not forwarded here.
            Remaining arguments are forwarded to ``triton.runtime.Autotuner``.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            # Only these positional arguments are passed to Triton 2.x's
            # Autotuner; hooks and `use_cuda_graph` are dropped.
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        else:
            # `self.base_fn` is not set here; presumably Triton 3.x's
            # Autotuner.__init__ sets it (it is read just below).
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        self.__name__ = self.base_fn.__name__
        self.keys = key
        # `None` is registered as an alias of the "default" strategy, so it is
        # looked up like any other name. Previously it was stored as-is and
        # later *called* in `get_key`.
        if strategy is None or isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        self.strategy: List[Callable[[Any], Any]] = [
            LibTuner.get_strategy(s) if s is None or isinstance(s, str) else s
            for s in strategy
        ]
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # `libcache[str]` returns a ConfigCache (tuning key -> best config).
        self.cache: ConfigCache = libcache[self.config_table_name]

    @cached_property
    def cache_key(self) -> str:
        """Cache key of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 over the kernel's cache key and the config search space."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """MD5 over the string forms of all candidate configs, in order."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning key from the named kernel arguments ``args``.

        Each entry of ``self.keys`` is bucketed by the matching strategy. Then the
        ``str(dtype)`` of every argument that has a ``dtype`` is appended.
        """
        if self.strategy is None:
            # Only reachable if a subclass resets `self.strategy` to None.
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                self.strategy[idx](args[name]) for idx, name in enumerate(self.keys)
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Choose the best config; called by ``run`` as ``self.policy(fn, configs, args, kwargs)``.

        Subclasses must override this, either as an instance method or as a
        ``@staticmethod`` taking the four arguments.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the `LibTuner` subclass registered under ``name`` (KeyError if absent)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Return the strategy registered under ``name`` (KeyError if absent)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Tuple[triton.Config, Dict[str, float]]]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table. The decorated name is bound to that subclass.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    return policy_impl(fn, configs, args, kwargs)

            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Return a decorator that registers a strategy (or list of strategies) under ``name``."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Select a config (tuning and persisting it on a cache miss), then launch.

        With a single candidate config, that config is used without tuning.
        """
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            # One persistent-cache read on the hot (already tuned) path instead
            # of an `in` test followed by `[]`.
            config = self.cache.get(key)
            if config is None:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(candidate: triton.Config) -> List[float]:
                    # Reuse persisted timings when available; otherwise measure and persist.
                    ret = cache.get(candidate)
                    if ret is None:
                        ret = self._bench(*args, config=candidate, **kwargs)
                        cache[candidate] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                # Read back once so the config used matches what the cache stores.
                config = self.cache[key]
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **config.all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            if config.pre_hook is None:
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

484 

485 

@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Identity strategy: the key value is used for cache lookup unchanged.

    Registered under both ``"default"`` and ``None``.
    """
    bucketed = key
    return bucketed

490 

491 

@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> float:
    """Round ``key`` up to the next power of two (powers of two map to themselves).

    Positive Python ints use exact integer arithmetic. The float formula
    ``2 ** ceil(log2(key))`` rounds for large ints: ``2**53 + 1`` would map
    to ``2**53``, which is *below* the key. Other inputs keep the float formula,
    so non-positive keys still raise ``ValueError`` from ``math.log2``.
    """
    if isinstance(key, int) and key > 0:
        # Smallest power of two >= key: 1 -> 1, 3 -> 4, 4 -> 4, 5 -> 8.
        return 1 << (key - 1).bit_length()
    return 2 ** math.ceil(math.log2(key))

495 

496 

497@LibTuner.register_strategy("align32") 

498def align32_strategy(key: Union[int, float]) -> int: 

499 return math.ceil(key / 32) * 32 

500 

501 

@LibTuner.register_policy("default")
def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Default policy for offline autotuning: exhaustive search.

    Every candidate in ``configs`` is benchmarked once, in order. The candidate
    whose result compares smallest wins; for the lists returned by ``bench_fn``
    that is a lexicographic comparison. On ties the earliest candidate is kept.

    Args:
        bench_fn: Benchmarks one configuration and returns its timings.
        configs: The configuration search space.
        args: Positional kernel launch arguments (unused by this policy).
        kwargs: Keyword kernel launch arguments (unused by this policy).
    Returns:
        ``(best_config, timings)``. ``timings`` maps each benchmarked
        configuration to the value ``bench_fn`` returned for it.
    Raises:
        ValueError: if ``configs`` is empty (raised by ``min``).

    Registering a function with ``LibTuner.register_policy`` is shorthand for
    writing the subclass by hand. This policy is equivalent to::

        @LibTuner.register("default")
        class DefaultLibTunerImpl(LibTuner):
            @staticmethod
            def policy(bench_fn, configs, args, kwargs):
                timings = {config: bench_fn(config) for config in configs}
                best_config = min(timings, key=timings.get)
                return best_config, timings

    Subclass ``LibTuner`` and override ``policy`` (or ``run``) directly when
    more control over the tuning process is needed.
    """
    timings: Dict[triton.Config, List[float]] = {}
    for candidate in configs:
        timings[candidate] = bench_fn(candidate)
    best_config, _ = min(timings.items(), key=lambda entry: entry[1])
    return best_config, timings

555 

556 

def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
):
    """Decorator for triton library autotuner.

    ``strategy`` maps each tuning-key value to the value used for cache lookup.
    It may be any of the following:
    - the name of a registered strategy, or a callable, applied to every entry
      of ``key``;
    - a list/tuple of the same length as ``key`` whose elements are names or
      callables, applied entry by entry.

    ``policy`` is the name of a registered ``LibTuner`` subclass, or such a
    subclass itself; the decorated kernel is wrapped in an instance of it.
    """
    if isinstance(policy, str):
        policy = LibTuner.get(policy)
    assert issubclass(
        policy, LibTuner
    ), f"the class of {policy.__name__} is {policy.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    # Keyword arguments forwarded unchanged to the tuner's constructor.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
        do_bench=do_bench,
        strategy=strategy,
    )

    def wrap(fn):
        return policy(
            fn, fn.arg_names, configs, key, reset_to_zero, restore_value, **tuner_options
        )

    return wrap

609 

610 

class LibEntry(triton.KernelInterface):
    """Per-device launch cache in front of a Triton kernel.

    The kernel may be wrapped in ``Autotuner`` / ``Heuristics`` layers. The first
    call for a given argument key goes through ``self.fn.run``, which
    compiles/tunes. The returned kernel is recorded together with the
    constexprs chosen by the wrapper layers. Later calls with the same key
    launch the recorded kernel directly and skip the wrapper layers.
    """

    def __init__(
        self,
        fn,
    ):
        self.fn = fn
        self.arg_names = fn.arg_names
        # Modulus applied to `data_ptr()` when building the specialization key.
        self.divisibility = 16
        # One {entry_key: (kernel, constexprs, tune_constexprs, heur_constexprs)} dict per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))

        # Unwrap Autotuner/Heuristics layers down to the JIT function.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature
        # Parameter names in declaration order; computed once instead of on every launch.
        self._param_names: List[str] = list(self.signature.parameters.keys())

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key for one launch.

        Specialized args contribute a tensor's (dtype, aligned?) pair, or
        (type, value) for other objects. Do-not-specialize args contribute
        only a dtype / integer width class. Constexpr args contribute their
        values.
        """

        def spec_arg(arg):
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    @staticmethod
    def _hashable(arg):
        """Return ``arg``, or a hashable summary of it if it is a TensorDescriptor."""
        if arg.__class__.__name__ == "TensorDescriptor":
            return (
                "TensorDescriptor",
                tuple(arg.shape) if hasattr(arg, "shape") else None,
                tuple(arg.strides) if hasattr(arg, "strides") else None,
                tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                arg.padding if hasattr(arg, "padding") else None,
                # Add other relevant attributes
            )
        return arg

    def _capture_constexprs(self, args, kwargs):
        """Collect the constexprs chosen by the wrapper layers around the JIT function.

        Returns:
            ``(constexprs, tune_constexprs, heur_constexprs)``.
            - ``constexprs`` holds everything needed for grid computation and
              launch (autotuned launch options and kwargs, heuristic values,
              and JIT-parameter defaults).
            - ``tune_constexprs`` holds the autotuner-chosen kwargs only.
            - ``heur_constexprs`` holds the heuristic values only.

        Raises:
            RuntimeError: if a wrapper layer is neither Autotuner nor Heuristics.
        """
        fn = self.fn
        constexprs = {}
        tune_constexprs = {}
        heur_constexprs = {}
        while not isinstance(fn, triton.runtime.JITFunction):
            if isinstance(fn, triton.runtime.Autotuner):
                config = fn.best_config
                constexprs["num_warps"] = config.num_warps
                constexprs["num_stages"] = config.num_stages
                constexprs["num_ctas"] = config.num_ctas
                constexprs = {**constexprs, **config.kwargs}
                tune_constexprs = {**tune_constexprs, **config.kwargs}
            elif isinstance(fn, triton.runtime.Heuristics):
                for v, heur in fn.values.items():
                    # Each heuristic sees the values produced by outer layers and earlier heuristics.
                    heur_constexprs[v] = heur(
                        {
                            **dict(zip(fn.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        }
                    )
                    constexprs[v] = heur_constexprs[v]
            else:
                raise RuntimeError("Invalid Runtime Function")
            fn = fn.fn
        for p in self.jit_function.params:
            if (
                p.is_constexpr
                and p.name not in constexprs
                and (p.default is not inspect._empty)
            ):
                constexprs[p.name] = p.default
        return constexprs, tune_constexprs, heur_constexprs

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling/tuning through ``self.fn`` on the first call per key.

        Returns:
            ``(kernel, constexprs)`` for the launch.

        Raises:
            RuntimeError: if a kernel parameter cannot be resolved for the
                launch (Triton 3.3-3.6 path).
        """
        grid = kwargs["grid"]
        # On these Triton versions the cached kernel takes constexprs as launch arguments too.
        const_in_launch = major_version == 3 and 3 <= minor_version <= 6
        param_names = self._param_names

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()
        for i, arg in enumerate(args):
            hashable_arg = self._hashable(arg)
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                if const_in_launch:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if const_in_launch:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                constexprs, tune_constexprs, heur_constexprs = self._capture_constexprs(
                    args, kwargs
                )
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                )
                return kernel, constexprs

        kernel, constexprs, tune_constexprs, heur_constexprs = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn,ie:
            # 1. args,
            # 2. kwargs,
            # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # `tuple()` so list-valued grids work here as they do on the first call
        # (via `self.fn.run`); pad so at least three dimensions are present.
        grid = tuple(grid) + (1, 1)

        if const_in_launch:
            all_args = []
            missing_keys = []
            for name in param_names:
                if name in k_args:
                    all_args.append(k_args[name])
                elif name in tune_constexprs:
                    all_args.append(tune_constexprs[name])
                elif name in heur_constexprs:
                    all_args.append(heur_constexprs[name])
                elif name in constexprs:
                    all_args.append(constexprs[name])
                else:
                    missing_keys.append(name)
            if missing_keys:
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs

797 

798 

def libentry():
    """Return a decorator that wraps a Triton kernel in a `LibEntry` launch cache.

    Usage: ``@libentry()`` placed above a (possibly autotuned/heuristic) kernel.
    """

    def wrap(fn):
        return LibEntry(fn)

    return wrap