Coverage for src/flag_gems/runtime/configloader.py: 62%

201 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

import copy
import itertools
import warnings

import triton

from . import backend, common
from .backend.device import DeviceDetector

8 

9 

class ConfigLoader:
    """Singleton that loads Triton autotune and heuristics configs for the device.

    Tune/heuristics configs are looked up in priority order:
    arch-specialized (from ``backend.BackendArchEvent``), then the detected
    vendor's configs, then the ``"nvidia"`` configs used as the default.
    """

    _instance = None

    # ------------------------------------------------------------------
    # Static layouts consumed by ``_build_configs_by_op``.
    # Each layout is a tuple of (META key, ``ranges`` key) pairs listed in
    # loop order, outermost first. The ``ranges["s"]`` (num_stages) and
    # ``ranges["w"]`` (num_warps) loops are always appended innermost, in
    # that order.
    # ------------------------------------------------------------------
    _MM_BLOCK_LAYOUT = (
        ("BLOCK_M", "BLOCK_M"),
        ("BLOCK_N", "BLOCK_N"),
        ("BLOCK_K", "BLOCK_K"),
    )
    _MM_BLOCK_SIZE_LAYOUT = (
        ("BLOCK_SIZE_M", "BLOCK_M"),
        ("BLOCK_SIZE_N", "BLOCK_N"),
        ("BLOCK_SIZE_K", "BLOCK_K"),
    )
    _MM_TILE_LAYOUT = (
        ("TILE_M", "BLOCK_M"),
        ("TILE_N", "BLOCK_N"),
        ("TILE_K", "BLOCK_K"),
    )
    _OP_META_LAYOUTS = {
        "bmm": _MM_TILE_LAYOUT,
        "addmm": _MM_BLOCK_SIZE_LAYOUT,
        "baddbmm": _MM_TILE_LAYOUT,
        "mv": (("BLOCK_N", "BLOCK_N"), ("BLOCK_M", "BLOCK_M")),
        "mm_general_tma": _MM_BLOCK_LAYOUT,
        "mm": _MM_BLOCK_LAYOUT,
        "mm_sqmma": _MM_BLOCK_LAYOUT,
        "bmm_sqmma": _MM_BLOCK_SIZE_LAYOUT,
        "addmm_sqmma": _MM_BLOCK_SIZE_LAYOUT,
        "gemv": (("BLOCK_M", "BLOCK_M"), ("BLOCK_K", "BLOCK_K")),
        "sparse_attention": (("BLOCK", "BLOCK"),),
        "w8a8_block_fp8_general": _MM_BLOCK_LAYOUT + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_tma": _MM_BLOCK_LAYOUT + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_splitk": _MM_BLOCK_LAYOUT
        + (("SPLIT_K", "SPLIT_K"),),
    }
    # META keys whose range may be absent from ``ranges`` (treated as
    # ``[None]``); a ``None`` value omits the key from META entirely.
    _OPTIONAL_META_KEYS = {"w8a8_block_fp8_general_tma": ("GROUP_M",)}
    # Ops whose META gets a GROUP_M derived from TILE_M, appended last.
    # NOTE(review): bmm uses ``== 32`` while baddbmm uses ``<= 32``; kept
    # exactly as found - confirm whether the difference is intentional.
    _DERIVED_GROUP_M = {
        "bmm": lambda tile_m: 1 if tile_m == 32 else 2,
        "baddbmm": lambda tile_m: 1 if tile_m <= 32 else 2,
    }
    # Expandable ops, in registry order; the first group is registered with
    # ``common.DEFAULT_EXPAND_CONFIG_PATH``, the rest with no path.
    _EXPAND_OPS_WITH_DEFAULT_PATH = ("bmm", "addmm", "baddbmm", "mv")
    _EXPAND_OPS = (
        "bmm",
        "addmm",
        "baddbmm",
        "mv",
        "w8a8_block_fp8_general",
        "w8a8_block_fp8_general_splitk",
        "w8a8_block_fp8_general_tma",
        "mm_general_tma",
        "gemv",
        "sparse_attention",
        "mm",
        "bmm_sqmma",
        "addmm_sqmma",
    )

    def __new__(cls, *args, **kargs):
        # Every ConfigLoader() call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Populate the singleton on first construction; later calls are no-ops."""
        # __new__ hands back the shared instance, so __init__ runs on every
        # ConfigLoader() call; only the first one loads anything.
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        self.update_config_from_arch()

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key marks a YAML entry whose configs must be generated
        # automatically (see to_gen_config).
        self.gen_key = "gen"
        # Cache of op_name -> list[triton.Config] built from the YAML configs.
        self.loaded_triton_config = {}
        # Launch parameters applied when a YAML entry does not specify them.
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name == "hygon":
            self.triton_config_default["num_ldmatrixes"] = 0
        self.expand_config_registry = self._build_expand_registry()
        self.load_all()

    def update_config_from_arch(self):
        """Load arch-specialized tune/heuristics configs, if the backend has any.

        Any error raised by the backend is reported on stdout and otherwise
        ignored, leaving the arch configs as they were.
        """
        try:
            arch_event = backend.BackendArchEvent()
            if arch_event.has_arch:
                self.arch_specialized_yaml_config = arch_event.autotune_configs
                self.arch_heuristics_config = arch_event.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Return the raw YAML config list for ``op_name``.

        Sources are checked arch -> vendor -> default; the first one that
        contains ``op_name`` wins. Returns ``[]`` if none does.
        """
        for config in (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        ):
            if config and op_name in config:
                return config[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Build a ``triton.Config`` from ``single_config["META"]``.

        Launch parameters (num_warps/num_stages/num_ctas, plus num_ldmatrixes
        on hygon) are read from ``current_config``.
        """
        kwargs = {
            "num_warps": current_config["num_warps"],
            "num_stages": current_config["num_stages"],
            "num_ctas": current_config["num_ctas"],
        }
        if self.device.vendor_name == "hygon":
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Build the cartesian product of ``triton.Config`` objects for ``op_name``.

        Args:
            op_name: op whose META layout is listed in ``_OP_META_LAYOUTS``.
            ranges: mapping of range key (e.g. ``"BLOCK_M"``, ``"s"``, ``"w"``)
                to the candidate values, as produced by ``get_expand_config``.
            pre_hook: forwarded unchanged to every ``triton.Config``.

        Returns:
            A list of configs in nested-loop order (outer META keys vary
            slowest, num_warps fastest), or ``[]`` for an unknown op.

        Raises:
            KeyError: if a required range key is missing from ``ranges``.
        """
        layout = self._OP_META_LAYOUTS.get(op_name)
        if layout is None:
            return []
        optional_keys = self._OPTIONAL_META_KEYS.get(op_name, ())
        derive_group_m = self._DERIVED_GROUP_M.get(op_name)

        meta_keys = [meta_key for meta_key, _ in layout]
        value_lists = [
            ranges.get(range_key, [None])
            if meta_key in optional_keys
            else ranges[range_key]
            for meta_key, range_key in layout
        ]

        configs = []
        # itertools.product varies its last iterable fastest, matching the
        # nested-for order: META keys (outer), then num_stages, then num_warps.
        for *meta_values, num_stages, num_warps in itertools.product(
            *value_lists, ranges["s"], ranges["w"]
        ):
            meta = {
                meta_key: value
                for meta_key, value in zip(meta_keys, meta_values)
                if not (value is None and meta_key in optional_keys)
            }
            if derive_group_m is not None:
                meta["GROUP_M"] = derive_group_m(meta["TILE_M"])
            configs.append(
                triton.Config(
                    meta,
                    num_stages=num_stages,
                    num_warps=num_warps,
                    pre_hook=pre_hook,
                )
            )
        return configs

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Return the expand spec dict for ``op_name``.

        ``key`` and ``default_strategy`` come from ``common.OP_KEY_ORDERS`` and
        ``common.DEFAULT_STRATEGIES`` (KeyError if ``op_name`` is missing).
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _build_expand_registry(self):
        """Map each op in ``_EXPAND_OPS`` to its expand spec.

        Ops in ``_EXPAND_OPS_WITH_DEFAULT_PATH`` get
        ``common.DEFAULT_EXPAND_CONFIG_PATH``; the others get no path.
        """
        default_path = common.DEFAULT_EXPAND_CONFIG_PATH
        return {
            op_name: self._build_single_expand_spec(
                op_name,
                expand_yaml_path=(
                    default_path
                    if op_name in self._EXPAND_OPS_WITH_DEFAULT_PATH
                    else None
                ),
            )
            for op_name in self._EXPAND_OPS
        }

    def load_all(self):
        """Pre-build and cache Triton configs for every op in the vendor tune config."""
        # The vendor tune config may be falsy/None (_get_op_configs guards for
        # that too); in that case there is nothing to preload.
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the heuristics config for the detected vendor."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the ``"nvidia"`` heuristics config, used as the fallback."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the ``"nvidia"`` tune config, used as the fallback."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the tune config for the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics entry for ``op_name``.

        Sources are checked arch -> vendor -> default. Any source that is
        None/empty is skipped; ``__init__`` explicitly warns that the vendor
        heuristics config may be None. Warns and returns ``None`` if no source
        has the op.
        """
        for config in (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        ):
            if config and op_name in config:
                return config[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Turn an iteration-plan source into the values to iterate over.

        A list/tuple is used as-is, an int becomes a one-element list, and
        anything else is treated as a key into ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into ``triton.Config`` objects.

        Performs an iterative depth-first walk. Each plan step fans the partial
        config out over that step's values:
        - ``meta_field`` sets one META key;
        - ``meta_block`` replaces META wholesale;
        - anything else sets a top-level launch field.
        Children are pushed onto a LIFO stack, so leaves are emitted in the
        reverse of nested-loop order.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [(std_config, 0)]

        while stack:
            cur_config, current_step = stack.pop()

            if current_step == final_step:
                # Shared with the static path so vendor-specific launch
                # params (hygon's num_ldmatrixes) are forwarded here too.
                all_configs.append(self._create_triton_config(cur_config, cur_config))
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            kind = cur_entry["kind"]
            values = self._resolve_iteration_values(gen_config, cur_entry["source"])
            for single_value in values:
                new_config = copy.deepcopy(cur_config)
                if kind == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif kind == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append((new_config, current_step + 1))
        return all_configs

    def to_gen_config(self, gen_config):
        """Generate Triton configs from a YAML entry that carries ``param_map``.

        ``param_map["META"]`` is either a dict (one plan step per META key) or
        a single source of whole META dicts. Every other ``param_map`` key
        becomes a plan step for that launch field. Fields not listed keep the
        values from ``self.triton_config_default``.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load expand ranges and strategy for ``op_name``.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success. Returns ``-1`` if
            the op is not registered, the backend does not return a list, no
            entry carries ``param_map``, or any error occurs while loading
            (best-effort).
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        # The registry path takes precedence; the caller's ``yaml_path`` is
        # only used for ops registered without one.
        expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
        yaml_op_name = op_spec.get("yaml_op_name", op_name)

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            # The last matching entry wins for both param_map and strategy.
            gen_config = None
            strategy_config = None
            for single_config in expand_configs:
                if not isinstance(single_config, dict):
                    continue
                if "param_map" in single_config:
                    gen_config = single_config
                if "strategy" in single_config:
                    strategy_config = single_config.get("strategy")

            if gen_config is None:
                return -1

            param_map = gen_config.get("param_map")
            meta_map = param_map.get("META")

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            # META values name the gen_config lists; ranges are keyed by the
            # upper-cased name (e.g. "block_m" -> "BLOCK_M").
            ranges = {
                mapped_key.upper(): gen_config[mapped_key]
                for mapped_key in meta_map.values()
            }
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            # Best-effort: malformed or missing expand YAML disables expansion.
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return expanded ``triton.Config`` objects for ``op_name``, or ``[]``."""
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return (possibly cached) ``triton.Config`` objects for ``op_name``.

        Entries containing ``self.gen_key`` are expanded via ``to_gen_config``.
        All other entries become one config each, with launch params taken
        from the entry or from ``self.triton_config_default``.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            # Only known launch params may be overridden by the YAML entry.
            current_config = copy.deepcopy(self.triton_config_default)
            for default_param in current_config:
                if default_param in single_config:
                    current_config[default_param] = single_config[default_param]

            configs.append(self._create_triton_config(single_config, current_config))
        return configs