Coverage for src/flag_gems/__init__.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
__version__ = "5.0.1.rc.0"

# Device identity as reported by the FlagGems runtime backend.
device, vendor_name = runtime.device.name, runtime.device.vendor_name

# Library handle used as the default `lib` by enable() / only_enable().
aten_lib = torch.library.Library("aten", "IMPL")

# Default registrar class, plus the registrar instance created by the most
# recent enable() / only_enable() call (None until one of them runs).
registrar = Register
current_work_registrar = None

# Hands this module's namespace (already populated by the star imports above)
# to the runtime; presumably it swaps in vendor-customized op implementations
# (TODO confirm in flag_gems.runtime). It runs before _FULL_CONFIG is built,
# so that table reads whatever names this call leaves in globals().
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Return True when the installed torch version is at least ``v``.

    Both versions are parsed with ``packaging.version``, so multi-digit
    components compare numerically (e.g. "2.10" is newer than "2.9").
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Master registration table. Each entry is
#     (aten op name[.overload], implementation[, condition])
# The optional third item is a zero-argument callable; both current uses gate
# the entry on torch >= 2.4 via torch_ge (presumably the registrar skips the
# entry when it returns False -- TODO confirm in Register).
# NOTE: names such as `abs`, `all`, `any`, `max`, `min` and `sum` below are
# presumably the FlagGems ops brought in by the star imports at the top of
# this module, shadowing the Python builtins of the same name here.
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_safe_softmax", _safe_softmax),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("log1p_", log1p_),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit_", logit_),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("selu", selu),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i1", special_i1),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero.out", zero_out),
    ("zero_", zero_),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
def _index_config_by_func(config):
    """Group registration entries by the name of their implementation callable.

    Entries that are empty or have fewer than two items are skipped. The key is
    the implementation's ``__name__``, falling back to ``str()`` for callables
    without one. Several aten overloads sharing one implementation (e.g.
    ``div.Tensor`` and ``true_divide.Tensor`` both use ``true_divide``) end up
    in the same list, in table order.

    Done inside a function so the loop variables do not leak into (and
    possibly clobber star-imported names in) the package namespace.
    """
    index = {}
    for entry in config:
        if not entry or len(entry) < 2:
            continue
        impl = entry[1]
        key = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        index.setdefault(key, []).append(entry)
    return index


# Cache mapping from function name -> list of _FULL_CONFIG entries for quick
# lookup; passed to the registrar by only_enable() (presumably to expand a
# user-supplied function name into every overload it implements).
FULL_CONFIG_BY_FUNC = _index_config_by_func(_FULL_CONFIG)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except the ones listed for exclusion.

    Args:
        lib: ``torch.library.Library`` the implementations are registered
            into. Defaults to the module-level ``aten_lib`` ("aten", "IMPL").
        unused: Ops to leave unregistered. Accepted forms:
            - a list/tuple/set of function names, e.g. ``["masked_fill", "mul"]``;
            - a str path to a ``.yml``/``.yaml`` file containing an
              ``exclude:`` list;
            - ``"default"`` or ``None`` to auto-load the vendor/arch-specific
              ``runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml`` if present.
            Resolution is delegated to ``resolve_user_setting(unused, "exclude")``.
        registrar: Registrar class to instantiate; defaults to ``Register``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path used when recording.

    Side effects:
        Rebinds the module-global ``current_work_registrar`` to the new
        registrar instance and configures logging via ``setup_flaggems_logging``.

    Notes:
        If the exclusion setting resolves to empty, all ops are registered.
    """
    global current_work_registrar

    skipped_ops = resolve_user_setting(unused, "exclude")
    # aten_patch_list is de-duplicated before being handed to the registrar.
    patched_by_cpp = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=skipped_ops,
        cpp_patched_ops=patched_by_cpp,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the requested FlagGems ops and skip all others.

    Args:
        lib: ``torch.library.Library`` the implementations are registered
            into. Defaults to the module-level ``aten_lib`` ("aten", "IMPL").
        include: Ops to register. Accepted forms:
            - a list/tuple/set of function names, e.g. ``["rms_norm", "softmax"]``;
            - a str path to a ``.yml``/``.yaml`` file (a plain list or an
              ``include:`` key);
            - ``"default"`` or ``None`` to auto-load the vendor/arch-specific
              ``runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml``
              if present.
            Resolution is delegated to ``resolve_user_setting(include, "include")``.
        registrar: Registrar class to instantiate; defaults to ``Register``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path used when recording.

    Typical usage:
        only_enable(include=["rms_norm", "softmax"])      # a few ops
        only_enable(include="default")                    # vendor YAML (or None)
        only_enable(include="/path/to/only_enable.yaml")  # custom YAML

    Notes:
        If the include setting resolves to nothing, a ``UserWarning`` is
        issued and nothing is registered (``current_work_registrar`` is left
        untouched). Handling of names that match no known op is presumably
        done by the registrar -- it is not visible in this function.
    """
    global current_work_registrar

    wanted_ops = resolve_user_setting(include, "include")
    if wanted_ops:
        current_work_registrar = registrar(
            _FULL_CONFIG,
            user_include_ops=wanted_ops,
            user_exclude_ops=[],
            cpp_patched_ops=list(set(aten_patch_list)),
            full_config_by_func=FULL_CONFIG_BY_FUNC,
            lib=lib,
        )
        setup_flaggems_logging(path=path, record=record, once=once)
        return

    warnings.warn(
        "only_enable failed: No include entries resolved from list or yaml."
    )
class use_gems:
    """Context manager that registers FlagGems ops only inside a ``with`` block.

    Example::

        with flag_gems.use_gems(include=["softmax"]) as gems:
            ...
            gems.experimental_ops  # the flag_gems.experimental_ops module

    Each instance creates its own ``torch.library.Library("aten", "IMPL")``.
    That library is torn down in ``__exit__``: explicitly via ``_destroy()``
    on torch >= 2.5, otherwise only the reference is dropped. An instance is
    single-use, because ``__exit__`` deletes its library and settings.

    Args:
        exclude: Ops to skip; forwarded to ``enable(unused=...)``.
        include: Ops to register exclusively; forwarded to ``only_enable``.
        record, once, path: Logging options forwarded to enable/only_enable.

    Only list/tuple/set/str values are kept for ``exclude``/``include``. Any
    other value, including ``None``, is normalized to ``[]``.

    The 'include' parameter has higher priority than 'exclude': when
    'include' is non-empty, 'exclude' is not processed.

    On exit, the module-global ``current_work_registrar`` is restored to the
    registrar that was active before ``__enter__``. Nested ``use_gems`` blocks
    and a preceding global ``enable()`` therefore keep working.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path
        # Registrar active before __enter__; restored by __exit__.
        self._prev_registrar = None

    def __enter__(self):
        # Remember the currently active registrar (from a global enable() or an
        # enclosing use_gems) so __exit__ can put it back instead of deleting
        # the module global, which broke nested blocks with a NameError.
        self._prev_registrar = current_work_registrar
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        # Returning self makes `with use_gems() as gems:` usable, e.g. for
        # gems.experimental_ops (previously `gems` was bound to None).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # torch_ge parses versions properly instead of relying on str compare.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        current_work_registrar = getattr(self, "_prev_registrar", None)
        self._prev_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return ``get_all_ops()`` of the currently active FlagGems registrar.

    The active registrar is the module-global ``current_work_registrar``. The
    contents of the returned collection are defined by the registrar class
    (``Register`` by default), which is not shown here. Raises if no registrar
    is active (the module-level default is ``None``).
    """
    active = current_work_registrar
    return active.get_all_ops()
def all_registered_keys():
    """Return ``get_all_keys()`` of the currently active FlagGems registrar.

    The active registrar is the module-global ``current_work_registrar``. The
    meaning of the keys is defined by the registrar class (``Register`` by
    default), which is not shown here. Raises if no registrar is active (the
    module-level default is ``None``).
    """
    active = current_work_registrar
    return active.get_all_keys()
# Names exported by `from flag_gems import *`. Because __all__ is set, the op
# functions pulled in by the star imports above are not re-exported that way.
__all__ = [
    "enable", "only_enable", "use_gems",
    "all_registered_ops", "all_registered_keys",
]