Coverage for src/flag_gems/runtime/configs_loader.py: 65%

237 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import copy 

2import inspect 

3import os 

4import warnings 

5 

6import triton 

7 

8from . import backend, common 

9from .backend.device import DeviceDetector 

10 

11 

12class TunedConfigLoader(object): 

13 _instance = None 

14 

def __new__(cls, *args, **kargs):
    """Return the shared loader instance, allocating it on the first call.

    The instance is stored on the class as ``_instance``. Positional and
    keyword arguments are accepted but not used by the allocation itself.
    """
    existing = cls._instance
    if existing is not None:
        return existing
    cls._instance = super(TunedConfigLoader, cls).__new__(cls)
    return cls._instance

19 

20 def __init__(self): 

21 if not hasattr(self, "initialized"): 

22 self.initialized = True 

23 self.device = DeviceDetector() 

24 # primitive_yaml_config is simply the dictionary returned by yaml 

25 # and is reserved from being an attr for vendor customizability 

26 self.arch_specialized_yaml_config = None 

27 self.arch_heuristics_config = None 

28 self.vendor_primitive_yaml_config = self.get_vendor_tune_config() 

29 self.default_primitive_yaml_config = self.get_default_tune_config() 

30 self.vendor_heuristics_config = self.get_vendor_heuristics_config() 

31 self.default_heuristics_config = self.get_default_heuristics_config() 

32 self.update_config_from_arch() 

33 

34 if self.vendor_heuristics_config is None: 

35 vendorname = self.device.vendor_name 

36 warnings.warn( 

37 f"The {vendorname} configuration of heuristics_config is None" 

38 ) 

39 # gen_key is an identifier that indicates whether the current config needs to be generated automatically 

40 self.gen_key = "gen" 

41 # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config 

42 self.loaded_triton_config = {} 

43 self.triton_config_default = { 

44 "num_stages": 2, 

45 "num_warps": 4, 

46 "num_ctas": 1, 

47 } 

48 if self.device.vendor_name == "hygon": 

49 self.triton_config_default["num_ldmatrixes"] = 0 

50 self.expand_config_registry = self._build_expand_registry() 

51 self.load_all() 

52 

def update_config_from_arch(self):
    """Adopt architecture-specific autotune and heuristics configs if present.

    Any exception raised while querying the backend is reported on stdout
    and otherwise ignored, leaving the current attributes untouched.
    """
    try:
        arch_event = backend.BackendArchEvent()
        if not arch_event.has_arch:
            return
        self.arch_specialized_yaml_config = arch_event.autotune_configs
        self.arch_heuristics_config = arch_event.heuristics_configs
    except Exception as err:
        print(f"[INFO] : {err}")

61 

62 def _get_op_configs(self, op_name): 

63 """Get config for op_name from available config sources.""" 

64 for config in ( 

65 self.arch_specialized_yaml_config, 

66 self.vendor_primitive_yaml_config, 

67 self.default_primitive_yaml_config, 

68 ): 

69 if config and op_name in config: 

70 return config[op_name] 

71 return [] 

72 

def _create_triton_config(self, single_config, current_config):
    """Build a ``triton.Config`` from one config entry and its launch parameters.

    ``single_config["META"]`` supplies the kernel meta-parameters, while
    ``current_config`` supplies ``num_warps``, ``num_stages`` and ``num_ctas``.
    On the "hygon" vendor, ``num_ldmatrixes`` is forwarded as well, but only
    when ``triton.Config`` actually accepts that parameter.
    """
    launch_args = {
        name: current_config[name]
        for name in ("num_warps", "num_stages", "num_ctas")
    }
    accepts_ldmatrixes = (
        self.device.vendor_name == "hygon"
        and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
    )
    if accepts_ldmatrixes:
        launch_args["num_ldmatrixes"] = current_config["num_ldmatrixes"]
    return triton.Config(single_config["META"], **launch_args)

86 

87 def _build_configs_by_op(self, op_name, ranges, pre_hook=None): 

88 if op_name == "bmm": 

89 return [ 

90 triton.Config( 

91 { 

92 "TILE_M": block_m, 

93 "TILE_N": block_n, 

94 "TILE_K": block_k, 

95 "GROUP_M": 1 if block_m == 32 else 2, 

96 }, 

97 num_stages=s, 

98 num_warps=w, 

99 pre_hook=pre_hook, 

100 ) 

101 for block_m in ranges["BLOCK_M"] 

102 for block_n in ranges["BLOCK_N"] 

103 for block_k in ranges["BLOCK_K"] 

104 for s in ranges["s"] 

105 for w in ranges["w"] 

106 ] 

107 

108 if op_name == "addmm": 

109 return [ 

110 triton.Config( 

111 { 

112 "BLOCK_SIZE_M": block_m, 

113 "BLOCK_SIZE_N": block_n, 

114 "BLOCK_SIZE_K": block_k, 

115 }, 

116 num_stages=s, 

117 num_warps=w, 

118 pre_hook=pre_hook, 

119 ) 

120 for block_m in ranges["BLOCK_M"] 

121 for block_n in ranges["BLOCK_N"] 

122 for block_k in ranges["BLOCK_K"] 

123 for s in ranges["s"] 

124 for w in ranges["w"] 

125 ] 

126 

127 if op_name == "baddbmm": 

128 return [ 

129 triton.Config( 

130 { 

131 "TILE_M": block_m, 

132 "TILE_N": block_n, 

133 "TILE_K": block_k, 

134 "GROUP_M": 1 if block_m <= 32 else 2, 

135 }, 

136 num_stages=s, 

137 num_warps=w, 

138 pre_hook=pre_hook, 

139 ) 

140 for block_m in ranges["BLOCK_M"] 

141 for block_n in ranges["BLOCK_N"] 

142 for block_k in ranges["BLOCK_K"] 

143 for s in ranges["s"] 

144 for w in ranges["w"] 

145 ] 

146 

147 if op_name == "mv": 

148 return [ 

149 triton.Config( 

150 { 

151 "BLOCK_N": block_n, 

152 "BLOCK_M": block_m, 

153 }, 

154 num_stages=s, 

155 num_warps=w, 

156 pre_hook=pre_hook, 

157 ) 

158 for block_n in ranges["BLOCK_N"] 

159 for block_m in ranges["BLOCK_M"] 

160 for s in ranges["s"] 

161 for w in ranges["w"] 

162 ] 

163 

164 if op_name == "mm_general_tma": 

165 group_m_values = ranges.get("GROUP_M", [8]) 

166 return [ 

167 triton.Config( 

168 { 

169 "BLOCK_M": block_m, 

170 "BLOCK_N": block_n, 

171 "BLOCK_K": block_k, 

172 "GROUP_M": group_m, 

173 }, 

174 num_stages=s, 

175 num_warps=w, 

176 pre_hook=pre_hook, 

177 ) 

178 for block_m in ranges["BLOCK_M"] 

179 for block_n in ranges["BLOCK_N"] 

180 for block_k in ranges["BLOCK_K"] 

181 for group_m in group_m_values 

182 for s in ranges["s"] 

183 for w in ranges["w"] 

184 ] 

185 

186 if op_name in ("mm", "mm_sqmma"): 

187 return [ 

188 triton.Config( 

189 { 

190 "BLOCK_M": block_m, 

191 "BLOCK_N": block_n, 

192 "BLOCK_K": block_k, 

193 }, 

194 num_stages=s, 

195 num_warps=w, 

196 pre_hook=pre_hook, 

197 ) 

198 for block_m in ranges["BLOCK_M"] 

199 for block_n in ranges["BLOCK_N"] 

200 for block_k in ranges["BLOCK_K"] 

201 for s in ranges["s"] 

202 for w in ranges["w"] 

203 ] 

204 

205 if op_name in ("bmm_sqmma", "addmm_sqmma"): 

206 return [ 

207 triton.Config( 

208 { 

209 "BLOCK_SIZE_M": block_m, 

210 "BLOCK_SIZE_N": block_n, 

211 "BLOCK_SIZE_K": block_k, 

212 }, 

213 num_stages=s, 

214 num_warps=w, 

215 pre_hook=pre_hook, 

216 ) 

217 for block_m in ranges["BLOCK_M"] 

218 for block_n in ranges["BLOCK_N"] 

219 for block_k in ranges["BLOCK_K"] 

220 for s in ranges["s"] 

221 for w in ranges["w"] 

222 ] 

223 

224 if op_name == "gemv": 

225 return [ 

226 triton.Config( 

227 {"BLOCK_M": block_m, "BLOCK_K": block_k}, 

228 num_stages=s, 

229 num_warps=w, 

230 pre_hook=pre_hook, 

231 ) 

232 for block_m in ranges["BLOCK_M"] 

233 for block_k in ranges["BLOCK_K"] 

234 for s in ranges["s"] 

235 for w in ranges["w"] 

236 ] 

237 

238 if op_name == "sparse_attention": 

239 return [ 

240 triton.Config( 

241 {"BLOCK": block}, 

242 num_stages=s, 

243 num_warps=w, 

244 pre_hook=pre_hook, 

245 ) 

246 for block in ranges["BLOCK"] 

247 for s in ranges["s"] 

248 for w in ranges["w"] 

249 ] 

250 

251 if op_name == "w8a8_block_fp8_general": 

252 return [ 

253 triton.Config( 

254 { 

255 "BLOCK_M": block_m, 

256 "BLOCK_N": block_n, 

257 "BLOCK_K": block_k, 

258 "GROUP_M": group_m, 

259 }, 

260 num_stages=s, 

261 num_warps=w, 

262 pre_hook=pre_hook, 

263 ) 

264 for block_m in ranges["BLOCK_M"] 

265 for block_n in ranges["BLOCK_N"] 

266 for block_k in ranges["BLOCK_K"] 

267 for group_m in ranges["GROUP_M"] 

268 for s in ranges["s"] 

269 for w in ranges["w"] 

270 ] 

271 

272 if op_name == "w8a8_block_fp8_general_tma": 

273 group_m_values = ranges.get("GROUP_M", [None]) 

274 return [ 

275 triton.Config( 

276 dict( 

277 { 

278 "BLOCK_M": block_m, 

279 "BLOCK_N": block_n, 

280 "BLOCK_K": block_k, 

281 }, 

282 **({} if group_m is None else {"GROUP_M": group_m}), 

283 ), 

284 num_stages=s, 

285 num_warps=w, 

286 pre_hook=pre_hook, 

287 ) 

288 for block_m in ranges["BLOCK_M"] 

289 for block_n in ranges["BLOCK_N"] 

290 for block_k in ranges["BLOCK_K"] 

291 for group_m in group_m_values 

292 for s in ranges["s"] 

293 for w in ranges["w"] 

294 ] 

295 

296 if op_name == "w8a8_block_fp8_general_splitk": 

297 return [ 

298 triton.Config( 

299 { 

300 "BLOCK_M": block_m, 

301 "BLOCK_N": block_n, 

302 "BLOCK_K": block_k, 

303 "SPLIT_K": split_k, 

304 }, 

305 num_stages=s, 

306 num_warps=w, 

307 pre_hook=pre_hook, 

308 ) 

309 for block_m in ranges["BLOCK_M"] 

310 for block_n in ranges["BLOCK_N"] 

311 for block_k in ranges["BLOCK_K"] 

312 for split_k in ranges["SPLIT_K"] 

313 for s in ranges["s"] 

314 for w in ranges["w"] 

315 ] 

316 

317 if op_name == "mm_splitk": 

318 return [ 

319 triton.Config( 

320 { 

321 "BLOCK_M": block_m, 

322 "BLOCK_N": block_n, 

323 "BLOCK_K": block_k, 

324 "SPLIT_K": split_k, 

325 }, 

326 num_stages=s, 

327 num_warps=w, 

328 pre_hook=pre_hook, 

329 ) 

330 for block_m in ranges["BLOCK_M"] 

331 for block_n in ranges["BLOCK_N"] 

332 for block_k in ranges["BLOCK_K"] 

333 for split_k in ranges["SPLIT_K"] 

334 for s in ranges["s"] 

335 for w in ranges["w"] 

336 ] 

337 

338 return [] 

339 

def _build_single_expand_spec(
    self,
    op_name,
    expand_yaml_path=None,
    yaml_op_name=None,
):
    """Assemble the expand-registry entry for ``op_name``.

    The entry records the op name to look up in the YAML file (falling back
    to ``op_name``), the key order and default strategy taken from
    ``common``, and the optional path of the expand YAML file.
    """
    spec = {}
    spec["yaml_op_name"] = yaml_op_name if yaml_op_name else op_name
    spec["key"] = common.OP_KEY_ORDERS[op_name]
    spec["default_strategy"] = common.DEFAULT_STRATEGIES[op_name]
    spec["expand_yaml_path"] = expand_yaml_path
    return spec

352 

353 def _iter_expand_config_candidates(self, op_name): 

354 vendor_name = self.device.vendor_name 

355 contexts = [] 

356 try: 

357 arch_event = backend.BackendArchEvent() 

358 current_arch_path = getattr(arch_event, "current_arch_path", None) 

359 arch_name = getattr(arch_event, "arch", None) 

360 if arch_event.has_arch and current_arch_path: 

361 contexts.append((current_arch_path, arch_name)) 

362 except Exception: 

363 pass 

364 

365 backend_dir = os.path.join(os.path.dirname(__file__), "backend") 

366 contexts.append((os.path.join(backend_dir, f"_{vendor_name}"), vendor_name)) 

367 

368 seen = set() 

369 for base_dir, backend_name in contexts: 

370 filenames = [] 

371 if op_name: 

372 filenames.extend( 

373 ( 

374 f"{op_name}_{backend_name}_expand.yaml", 

375 f"{op_name}_{vendor_name}_expand.yaml", 

376 f"{op_name}_expand.yaml", 

377 ) 

378 ) 

379 filenames.extend( 

380 ( 

381 f"general_ops_{backend_name}_expand.yaml", 

382 f"general_ops_{vendor_name}_expand.yaml", 

383 "general_ops_expand.yaml", 

384 ) 

385 ) 

386 

387 for filename in filenames: 

388 path = os.path.normpath(os.path.join(base_dir, filename)) 

389 if path in seen: 

390 continue 

391 seen.add(path) 

392 yield path 

393 

394 def _get_expand_config_path(self, op_name): 

395 for path in self._iter_expand_config_candidates(op_name): 

396 if os.path.exists(path): 

397 return path 

398 return None 

399 

400 def _build_expand_registry(self): 

401 return { 

402 "addmm": self._build_single_expand_spec( 

403 "addmm", expand_yaml_path=self._get_expand_config_path("addmm") 

404 ), 

405 "addmm_sqmma": self._build_single_expand_spec("addmm_sqmma"), 

406 "baddbmm": self._build_single_expand_spec( 

407 "baddbmm", expand_yaml_path=self._get_expand_config_path("baddbmm") 

408 ), 

409 "bmm": self._build_single_expand_spec( 

410 "bmm", expand_yaml_path=self._get_expand_config_path("bmm") 

411 ), 

412 "bmm_sqmma": self._build_single_expand_spec("bmm_sqmma"), 

413 "gemv": self._build_single_expand_spec("gemv"), 

414 "mm": self._build_single_expand_spec( 

415 "mm", expand_yaml_path=self._get_expand_config_path("mm") 

416 ), 

417 "mm_general_tma": self._build_single_expand_spec("mm_general_tma"), 

418 "mv": self._build_single_expand_spec( 

419 "mv", expand_yaml_path=self._get_expand_config_path("mv") 

420 ), 

421 "w8a8_block_fp8_general": self._build_single_expand_spec( 

422 "w8a8_block_fp8_general" 

423 ), 

424 "w8a8_block_fp8_general_splitk": self._build_single_expand_spec( 

425 "w8a8_block_fp8_general_splitk" 

426 ), 

427 "w8a8_block_fp8_general_tma": self._build_single_expand_spec( 

428 "w8a8_block_fp8_general_tma" 

429 ), 

430 "mm_splitk": self._build_single_expand_spec("mm_splitk"), 

431 "sparse_attention": self._build_single_expand_spec("sparse_attention"), 

432 } 

433 

def load_all(self):
    """Pre-build and cache the tuned configs for every vendor-listed op.

    A missing (``None``) vendor tune config is treated as "no ops to
    preload", matching how ``_get_op_configs`` skips ``None`` sources,
    instead of failing with ``TypeError``.
    """
    for op_name in self.vendor_primitive_yaml_config or ():
        self.loaded_triton_config[op_name] = self.get_tuned_config(op_name)

437 

def get_vendor_heuristics_config(self):
    """Return the heuristics config the backend provides for the detected vendor."""
    vendor_name = self.device.vendor_name
    return backend.get_heuristic_config(vendor_name)

440 

def get_default_heuristics_config(self):
    """Return the fallback heuristics config, which is the "nvidia" one."""
    fallback_vendor = "nvidia"
    return backend.get_heuristic_config(fallback_vendor)

443 

def get_default_tune_config(self):
    """Return the fallback tune config, which is the "nvidia" one."""
    fallback_vendor = "nvidia"
    return backend.get_tune_config(fallback_vendor)

446 

def get_vendor_tune_config(self):
    """Return the tune config the backend provides for the detected vendor."""
    vendor_name = self.device.vendor_name
    return backend.get_tune_config(vendor_name)

449 

def get_heuristics_config(self, op_name):
    """Return the heuristics config for ``op_name``, or ``None`` if unknown.

    Sources are consulted in priority order: architecture-specific, vendor,
    then default. Any source may be ``None`` (``__init__`` only warns when the
    vendor heuristics are missing), so every source is guarded before the
    membership test rather than raising ``TypeError``.

    A warning is emitted when no source contains ``op_name``.
    """
    sources = (
        self.arch_heuristics_config,
        self.vendor_heuristics_config,
        self.default_heuristics_config,
    )
    for source in sources:
        if source and op_name in source:
            return source[op_name]

    warnings.warn(f"No heuristics config found for {op_name}")
    return None

460 

461 def _resolve_iteration_values(self, gen_config, config_var_key): 

462 if isinstance(config_var_key, (list, tuple)): 

463 return config_var_key 

464 if isinstance(config_var_key, int): 

465 return [config_var_key] 

466 return gen_config[config_var_key] 

467 

468 def _gen_impl( 

469 self, 

470 gen_config, 

471 iteration_plan, 

472 std_config, 

473 ): 

474 all_configs = [] 

475 final_step = len(iteration_plan) 

476 stack = [{"cur_config": std_config, "current_step": 0}] 

477 

478 while stack: 

479 cur_state = stack[-1] 

480 stack.pop() 

481 cur_config = cur_state.get("cur_config") 

482 current_step = cur_state.get("current_step") 

483 

484 if current_step == final_step: 

485 all_configs.append( 

486 triton.Config( 

487 cur_config["META"], 

488 num_warps=cur_config["num_warps"], 

489 num_stages=cur_config["num_stages"], 

490 num_ctas=cur_config["num_ctas"], 

491 ) 

492 ) 

493 else: 

494 cur_entry = iteration_plan[current_step] 

495 cur_key = cur_entry["key"] 

496 key_config = self._resolve_iteration_values( 

497 gen_config, cur_entry["source"] 

498 ) 

499 for single_value in key_config: 

500 new_config = copy.deepcopy(cur_config) 

501 if cur_entry["kind"] == "meta_field": 

502 new_config["META"][cur_key] = single_value 

503 elif cur_entry["kind"] == "meta_block": 

504 new_config["META"] = copy.deepcopy(single_value) 

505 else: 

506 new_config[cur_key] = single_value 

507 stack.append( 

508 { 

509 "cur_config": new_config, 

510 "current_step": current_step + 1, 

511 } 

512 ) 

513 return all_configs 

514 

def to_gen_config(self, gen_config):
    """Expand a generator-style config entry into concrete configs.

    ``gen_config["param_map"]`` is turned into an iteration plan: META entries
    come first (one step per META key when META is a dict, otherwise a single
    step that swaps in the whole META block), followed by one step per
    remaining parameter. The plan is then handed to ``_gen_impl`` together
    with a seed config built from the default launch parameters.
    """
    param_config = gen_config["param_map"]
    meta_config = param_config["META"]

    if isinstance(meta_config, dict):
        iteration_plan = [
            {"key": meta_key, "source": source, "kind": "meta_field"}
            for meta_key, source in meta_config.items()
        ]
    else:
        iteration_plan = [
            {"key": "META", "source": meta_config, "kind": "meta_block"}
        ]

    iteration_plan += [
        {"key": key, "source": source, "kind": "config_field"}
        for key, source in param_config.items()
        if key != "META"
    ]

    seed_config = {"META": {}, **self.triton_config_default}
    return self._gen_impl(gen_config, iteration_plan, seed_config)

544 

def get_expand_config(self, op_name, yaml_path=None):
    """Load the expand ranges and strategy for ``op_name``.

    Args:
        op_name: Operator whose spec is looked up in the expand registry.
        yaml_path: Optional path overriding the one stored in the registry.

    Returns:
        ``{"ranges": ..., "strategy": ...}`` on success, or ``-1`` when the op
        is not registered, no YAML path is available, the YAML content is not
        a list, or anything goes wrong while reading it.
    """
    op_spec = self.expand_config_registry.get(op_name)
    if op_spec is None:
        return -1

    key = op_spec.get("key", [])
    default_strategy = op_spec.get("default_strategy")
    expand_yaml_path = yaml_path or op_spec.get("expand_yaml_path")
    yaml_op_name = op_spec.get("yaml_op_name", op_name)
    if not expand_yaml_path:
        return -1

    try:
        loaded_entries = backend.get_expand_config(
            op_name=yaml_op_name,
            file_path=expand_yaml_path,
        )
        if not isinstance(loaded_entries, list):
            return -1

        # The last entry carrying each marker wins.
        gen_config = None
        strategy_config = None
        for entry in loaded_entries:
            if not isinstance(entry, dict):
                continue
            if "param_map" in entry:
                gen_config = entry
            if "strategy" in entry:
                strategy_config = entry.get("strategy")

        param_map = gen_config.get("param_map")
        meta_map = param_map.get("META")

        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(name, default_strategy[position])
                for position, name in enumerate(key)
            ]
        else:
            strategy = default_strategy

        ranges = {
            mapped_key.upper(): gen_config[mapped_key]
            for mapped_key in meta_map.values()
        }
        ranges["s"] = gen_config[param_map.get("num_stages")]
        ranges["w"] = gen_config[param_map.get("num_warps")]

        return {"ranges": ranges, "strategy": strategy}
    except Exception:
        return -1

597 

def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
    """Return the expanded ``triton.Config`` list for ``op_name``.

    An empty list is returned when no expand configuration could be loaded
    (signalled by ``get_expand_config`` returning ``-1``).
    """
    expanded = self.get_expand_config(op_name, yaml_path=yaml_path)
    return (
        []
        if expanded == -1
        else self._build_configs_by_op(op_name, expanded["ranges"], pre_hook=pre_hook)
    )

604 

def get_tuned_config(self, op_name):
    """Return the list of ``triton.Config`` objects for ``op_name``.

    Results already present in ``loaded_triton_config`` are returned directly.
    Otherwise each raw config entry is either expanded (when it contains the
    ``gen_key`` marker) or converted one-to-one, with the default launch
    parameters overridden by any values the entry provides. An empty list is
    returned when no source has entries for the op.
    """
    if op_name in self.loaded_triton_config:
        return self.loaded_triton_config[op_name]

    raw_entries = self._get_op_configs(op_name)
    if not raw_entries:
        return []

    built = []
    for entry in raw_entries:
        if self.gen_key in entry:
            built.extend(self.to_gen_config(entry))
        else:
            launch_params = copy.deepcopy(self.triton_config_default)
            launch_params.update(
                {name: entry[name] for name in launch_params if name in entry}
            )
            built.append(self._create_triton_config(entry, launch_params))
    return built