Coverage for src/flag_gems/__init__.py: 89%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import warnings 

2 

3import torch 

4from packaging import version 

5 

6from flag_gems import testing # noqa: F401 

7from flag_gems import runtime 

8from flag_gems.config import aten_patch_list, resolve_user_setting 

9from flag_gems.experimental_ops import * # noqa: F403 

10from flag_gems.fused import * # noqa: F403 

11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging 

12from flag_gems.modules import * # noqa: F403 

13from flag_gems.ops import * # noqa: F403 

14from flag_gems.patches import * # noqa: F403 

15from flag_gems.runtime.register import Register 

16 

__version__ = "4.2.1.rc.0"
# Device and vendor identity resolved by the runtime for the active backend.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Library handle (IMPL mode) into which FlagGems op implementations are registered.
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class; enable()/only_enable() accept an override.
registrar = Register
# Register instance created by the most recent enable()/only_enable() call;
# None until one of them has run.
current_work_registrar = None
# Let the active vendor backend replace entries in this module's namespace
# with its customized op implementations.
runtime.replace_customized_ops(globals())

24 

25 

def torch_ge(v):
    """Return True if the installed torch version is at least *v*.

    Args:
        v: Minimum version as a string (e.g. "2.4").
    """
    required = version.parse(v)
    return version.parse(torch.__version__) >= required

28 

29 

# Master registration table: every aten op FlagGems can override.
# Each entry is a 2- or 3-tuple:
#   (aten_key, implementation[, predicate])
# where `aten_key` names the aten operator (with optional overload suffix,
# e.g. "add.Tensor"), `implementation` is the FlagGems function registered
# for it (imported via the star-imports above), and the optional zero-arg
# `predicate` gates registration at enable time (used below to require
# torch >= 2.4 for "_to_copy" and "copy_").
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    (
        "_to_copy",
        to_copy,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("argmax", argmax),
    ("argmin", argmin),
    ("atan", atan),
    ("atan_", atan_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    (
        "copy_",
        copy_,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("relu", relu),
    ("relu_", relu_),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_scatter", select_scatter),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("softplus", softplus),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zeros", zeros),
    ("zero_", zero_),
    ("zeros_like", zeros_like),
)

378 

def _index_config_by_func(config):
    """Group registration entries by the __name__ of their implementation.

    Args:
        config: Iterable of (aten_key, implementation[, predicate]) tuples.

    Returns:
        dict mapping implementation name -> list of matching entries. An
        implementation registered under several aten keys appears once per
        entry. Malformed entries (empty, or missing the implementation
        element) are skipped defensively.
    """
    index = {}
    for entry in config:
        if not entry or len(entry) < 2:
            continue
        impl = entry[1]
        # Fall back to str() for callables without __name__ (e.g. partials).
        key = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        index.setdefault(key, []).append(entry)
    return index


# Cache mapping from function name -> list of _FULL_CONFIG entries for quick
# lookup. Built via a helper so loop temporaries (previously `fn`,
# `func_name`, `_item`) do not leak into the module namespace.
FULL_CONFIG_BY_FUNC = _index_config_by_func(_FULL_CONFIG)

387 

388 

def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Which ops to skip. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar

    excluded = resolve_user_setting(unused, "exclude")
    deduped_cpp_ops = list(set(aten_patch_list))

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=excluded,
        cpp_patched_ops=deduped_cpp_ops,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)

426 

427 

def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty or none of the names match
          known ops, the function warns and returns without registering.
    """
    selected = resolve_user_setting(include, "include")

    # Guard clause: nothing resolved means nothing to register.
    if not selected:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return

    global current_work_registrar
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=selected,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)

481 

482 

class use_gems:
    """Context manager that temporarily registers FlagGems aten overrides.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is not None, use_gems will not process 'exclude'.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # A per-context library so registrations can be torn down on exit.
        self.lib = torch.library.Library("aten", "IMPL")
        # Non-list-like values (e.g. None) are normalized to "no filter".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        # Return the manager so `with use_gems() as gems:` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # BUG FIX: the original compared version *strings*
        # (torch.__version__ >= "2.5"), which is lexicographic — e.g.
        # "2.10" < "2.5" — so cleanup would be skipped on newer torch
        # releases. Use the module's torch_ge helper for a parsed compare.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        del current_work_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        # Imported lazily to avoid paying for experimental ops unless asked.
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops

535 

536 

def all_registered_ops():
    """Return the ops registered by the most recent enable()/only_enable() call.

    Raises:
        AttributeError: if called before any registration has happened
            (current_work_registrar is still None).
    """
    return current_work_registrar.get_all_ops()

539 

540 

def all_registered_keys():
    """Return the registration keys from the most recent enable()/only_enable() call.

    Raises:
        AttributeError: if called before any registration has happened
            (current_work_registrar is still None).
    """
    return current_work_registrar.get_all_keys()

543 

544 

# Explicit public API of the package entry module; ops pulled in by the
# star-imports above are intentionally not re-exported here.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]