Coverage for src/flag_gems/__init__.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
# Package version string (PEP 440-ish; note the ".rc." pre-release segment).
__version__ = "4.2.1.rc.0"
# Device/vendor identity resolved by the runtime backend at import time.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Default torch Library handle used for aten op registrations ("IMPL" mode).
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class; the instance created by the most recent
# enable()/only_enable() call is kept in `current_work_registrar`.
registrar = Register
current_work_registrar = None
# NOTE(review): presumably lets the active vendor backend swap customized op
# implementations into this module's namespace — confirm in flag_gems.runtime.
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Return True when the installed torch version is at least *v*.

    Args:
        v: Version string to compare against (e.g. "2.5").
    """
    current = version.parse(torch.__version__)
    return current >= version.parse(v)
# Master op-registration table consumed by the registrar (see enable() and
# only_enable()). Each entry is a tuple of:
#   (aten op / overload name, flag_gems implementation[, predicate])
# where the optional third element is a zero-argument callable gating the
# registration (here: a minimum torch version check).
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    (
        "_to_copy",
        to_copy,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("argmax", argmax),
    ("argmin", argmin),
    ("atan", atan),
    ("atan_", atan_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    (
        "copy_",
        copy_,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Tensor", fill_tensor),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hstack", hstack),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("relu", relu),
    ("relu_", relu_),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_scatter", select_scatter),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("slice_scatter", slice_scatter),
    ("softplus", softplus),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.ScalarOther", where_scalar_other),
    ("where.ScalarSelf", where_scalar_self),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zeros", zeros),
    ("zero_", zero_),
    ("zeros_like", zeros_like),
)
# Index _FULL_CONFIG entries by the implementing function's name so callers
# (e.g. only_enable) can resolve user-supplied names without scanning the
# whole table each time.
FULL_CONFIG_BY_FUNC = {}
for _entry in _FULL_CONFIG:
    # Defensive: skip empty or malformed entries.
    if not _entry or len(_entry) < 2:
        continue
    _impl = _entry[1]
    _key = _impl.__name__ if hasattr(_impl, "__name__") else str(_impl)
    if _key not in FULL_CONFIG_BY_FUNC:
        FULL_CONFIG_BY_FUNC[_key] = []
    FULL_CONFIG_BY_FUNC[_key].append(_entry)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op, minus an optional exclusion set.

    Args:
        lib: torch.library.Library to register into; defaults to the
            module-level ``aten_lib`` (IMPL mode).
        unused: Ops to skip. Accepts a list/tuple/set of function names,
            a path to a ``.yml``/``.yaml`` file with an ``exclude:`` list,
            or ``"default"``/``None`` to auto-load the vendor/arch-specific
            runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to ``Register``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        An empty exclusion list (or YAML) means all ops are registered.
    """
    global current_work_registrar
    excluded = resolve_user_setting(unused, "exclude")
    # NOTE(review): unlike only_enable, no full_config_by_func cache is passed
    # here — presumably the registrar needs it only for include lookups; confirm.
    current_work_registrar = registrar(
        _FULL_CONFIG,
        lib=lib,
        user_include_ops=[],
        user_exclude_ops=excluded,
        cpp_patched_ops=list(set(aten_patch_list)),
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the named FlagGems ops and leave the rest untouched.

    Args:
        lib: torch.library.Library to register into; defaults to the
            module-level ``aten_lib`` (IMPL mode).
        include: Ops to register. Accepts a list/tuple/set of function names,
            a path to a ``.yml``/``.yaml`` file (a plain list or an
            ``include:`` key), or ``"default"``/``None`` to auto-load the
            vendor/arch-specific
            runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml.
        registrar: Registrar class; defaults to ``Register``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Examples:
        only_enable(include=["rms_norm", "softmax"])   # a few ops
        only_enable(include="default")                 # vendor default YAML
        only_enable(include="/path/to/only_enable.yaml")

    Notes:
        If no include entry resolves to a known op, a warning is emitted and
        nothing is registered.
    """
    global current_work_registrar
    selected = resolve_user_setting(include, "include")
    # Guard clause: nothing to register — warn and bail out.
    if not selected:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return
    current_work_registrar = registrar(
        _FULL_CONFIG,
        lib=lib,
        user_include_ops=selected,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that temporarily routes aten ops to FlagGems kernels.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is not None, use_gems will not process 'exclude'.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # Private Library so registrations can be torn down on __exit__
        # without disturbing the module-level `aten_lib`.
        self.lib = torch.library.Library("aten", "IMPL")
        # Non-collection values (e.g. None) are normalized to "no filter".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        # Include takes precedence; exclude is ignored when include is given.
        # NOTE: returns None, so `with use_gems() as g:` binds g to None.
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # BUGFIX: the original compared raw strings
        # (`torch.__version__ >= "2.5"`), which is lexicographic and wrong for
        # versions such as "2.10" ("2.10" < "2.5" as strings). Use the parsed
        # comparison helper instead, consistent with the rest of this module.
        if torch_ge("2.5"):
            # Library._destroy() (available since torch 2.5) removes the
            # registrations made through this private library.
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Drops the module-level registrar binding created by __enter__.
        del current_work_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """Expose the flag_gems.experimental_ops module lazily."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return the ops registered by the most recent enable()/only_enable()."""
    active = current_work_registrar
    return active.get_all_ops()
def all_registered_keys():
    """Return the op keys registered by the most recent enable()/only_enable()."""
    active = current_work_registrar
    return active.get_all_keys()
# Public API surface of the flag_gems package.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]