Coverage for src/flag_gems/__init__.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
# Package version string.
__version__ = "4.2.1.rc.0"
# Device name and vendor name as reported by ``runtime.device``.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Module-wide ``aten`` library opened in IMPL mode; it is the default ``lib``
# argument of ``enable`` and ``only_enable``.
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class used by ``enable`` / ``only_enable``.
registrar = Register
# Registrar instance created by the most recent ``enable`` / ``only_enable``
# call; ``None`` until one of them has run.
current_work_registrar = None
# Hand this module's namespace to the runtime so it can substitute
# vendor/backend-customized op implementations.
# NOTE(review): presumably this must run before ``_FULL_CONFIG`` is built so
# the table picks up the replaced callables — confirm against
# ``runtime.replace_customized_ops``.
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Return ``True`` when the installed torch version is at least ``v``.

    Both versions are parsed with ``packaging.version``, so the comparison is
    semantic (``"2.10"`` sorts after ``"2.5"``), not lexicographic.

    Args:
        v: Minimum required version string, e.g. ``"2.4"``.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Master registration table of every aten op FlagGems can override.
# Each entry is either ``(aten_op_name, impl)`` or
# ``(aten_op_name, impl, condition)``:
#   - ``aten_op_name`` is the aten schema name, optionally with an
#     ``.overload`` suffix (e.g. ``"add.Tensor"``);
#   - ``impl`` is the FlagGems callable (several overloads may share one);
#   - ``condition`` is a zero-argument callable.
# NOTE(review): ``condition`` is presumably evaluated by the registrar to
# decide whether the entry is registered on the running torch — confirm in
# ``Register``.
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_safe_softmax", _safe_softmax),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    (
        "_to_copy",
        to_copy,
        # Reuse the module's version helper instead of re-parsing by hand.
        lambda: torch_ge("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    (
        "copy_",
        copy_,
        # Reuse the module's version helper instead of re-parsing by hand.
        lambda: torch_ge("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("log1p_", log1p_),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_scatter", select_scatter),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("selu", selu),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i1", special_i1),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero.out", zero_out),
    ("zero_", zero_),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
def _index_config_by_func(config):
    """Group registration entries by the name of their implementation callable.

    Built inside a function so the loop temporaries do not leak into the
    package namespace as stray module attributes.

    Args:
        config: Iterable of ``(aten_op_name, impl[, condition])`` tuples.

    Returns:
        A plain ``dict`` mapping ``impl.__name__`` (or ``str(impl)`` when the
        callable has no ``__name__``) to the list of entries that use it, in
        their original order. Empty entries, or entries with fewer than two
        items, are skipped.
    """
    index = {}
    for entry in config:
        if not entry or len(entry) < 2:
            continue
        impl = entry[1]
        key = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        index.setdefault(key, []).append(entry)
    return index


# Cache mapping from function name -> list of _FULL_CONFIG entries for quick
# lookup (several aten overloads can share one implementation, e.g.
# ``div.Scalar`` and ``div.Tensor`` both map to ``true_divide``).
FULL_CONFIG_BY_FUNC = _index_config_by_func(_FULL_CONFIG)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except those the caller excludes.

    Args:
        lib: ``torch.library.Library`` to register into; defaults to the
            module-wide ``aten_lib`` (IMPL mode).
        unused: Ops to skip. Accepted forms:
            - a list/tuple/set of function names, e.g. ``["masked_fill", "mul"]``;
            - a ``str`` path to a ``.yml``/``.yaml`` file containing an
              ``exclude:`` list;
            - ``"default"`` or ``None``: auto-load the vendor/arch-specific
              ``runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml`` if present.
        registrar: Registrar class to instantiate; defaults to ``Register``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
        - The new registrar instance is stored in the module-level
          ``current_work_registrar``.
    """
    global current_work_registrar
    # Resolve the exclusions first, then the C++-patched op names; the
    # registrar receives the full table with no explicit include filter.
    registrar_kwargs = dict(
        user_include_ops=[],
        user_exclude_ops=resolve_user_setting(unused, "exclude"),
        cpp_patched_ops=list(set(aten_patch_list)),
        lib=lib,
    )
    current_work_registrar = registrar(_FULL_CONFIG, **registrar_kwargs)
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, a ``UserWarning`` is
          emitted (attributed to the caller) and the function returns without
          registering anything; ``current_work_registrar`` is left unchanged.
        - Otherwise a new registrar instance is stored in the module-level
          ``current_work_registrar``.
        - NOTE(review): matching include names against known ops is delegated
          to ``registrar`` via ``full_config_by_func``; how names that match
          no op are reported is up to the registrar — confirm there.
    """
    global current_work_registrar
    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        # stacklevel=2 points the warning at the caller that passed the
        # unusable ``include`` value, not at this module.
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml.",
            stacklevel=2,
        )
        return

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that enables FlagGems ops only inside a ``with`` block.

    Each instance creates its own ``aten`` IMPL library in ``__init__``.
    ``__enter__`` registers ops into it through ``only_enable`` (when
    ``include`` is non-empty) or ``enable`` (otherwise) and returns the
    instance. ``__exit__`` tears the library down and resets the module-level
    ``current_work_registrar`` to ``None``.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is non-empty, use_gems will not process 'exclude'.

    Args:
        exclude: Ops to skip, in any form accepted by ``enable(unused=...)``.
            Values that are not a list/tuple/set/str (e.g. ``None``) are
            replaced by ``[]``. NOTE(review): because ``None`` becomes ``[]``,
            the vendor default YAML that ``enable(unused=None)`` may auto-load
            is presumably not used here — confirm this is intended.
        include: Ops to register exclusively, in any form accepted by
            ``only_enable(include=...)``; normalized the same way as
            ``exclude``.
        record: Whether to enable FlagGems logging; it is torn down on exit.
        once: When True, log only once.
        path: Optional log output path when recording.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        # Return the instance so ``with use_gems() as gems:`` can reach
        # ``gems.experimental_ops``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # ``Library._destroy`` is only called on torch >= 2.5 (presumably
        # where it is available). Use the module's semantic version helper,
        # like the rest of this file does.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Reset instead of ``del``: deleting the module global made a later
        # exit (e.g. after ``only_enable`` returned early without setting it)
        # raise NameError, as did ``all_registered_ops()`` after the block.
        current_work_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """Return the ``flag_gems.experimental_ops`` module, imported lazily."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return ``get_all_ops()`` of the registrar created by the latest
    ``enable`` / ``only_enable`` call (stored in ``current_work_registrar``).
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return ``get_all_keys()`` of the registrar created by the latest
    ``enable`` / ``only_enable`` call (stored in ``current_work_registrar``).
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Public API: the only names exported by ``from flag_gems import *``.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]