Coverage for src/flag_gems/__init__.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
__version__ = "4.2.1.rc.0"
# Device/vendor identity resolved by the runtime backend.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Global IMPL-mode library that enable()/only_enable() register into by default.
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class; overridable per call to enable()/only_enable().
registrar = Register
# Holds the Register instance created by the most recent enable()/only_enable();
# None until one of them has been called.
current_work_registrar = None
# Let the active vendor backend swap in its customized op implementations.
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Return True when the installed torch release is at least version *v*.

    Uses PEP 440 semantics via ``packaging.version`` rather than string
    comparison, so e.g. "2.10" correctly compares greater than "2.5".
    """
    installed = version.parse(torch.__version__)
    return installed >= version.parse(v)
# Master table of aten-op -> FlagGems implementation registrations.
# Each entry is either:
#   (aten_op_overload_name, flag_gems_function)
#   (aten_op_overload_name, flag_gems_function, predicate)
# where the optional zero-arg predicate gates registration (used below to
# require torch >= 2.4 for "_to_copy" and "copy_").
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    (
        "_to_copy",
        to_copy,
        # Only available on torch >= 2.4.
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("argmax", argmax),
    ("argmin", argmin),
    ("atan", atan),
    ("atan_", atan_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    (
        "copy_",
        copy_,
        # Only available on torch >= 2.4.
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("relu", relu),
    ("relu_", relu_),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_scatter", select_scatter),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("softplus", softplus),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zeros", zeros),
    ("zero_", zero_),
    ("zeros_like", zeros_like),
)
# Index of _FULL_CONFIG entries keyed by implementation-function name, so an
# op can be looked up by the name of the FlagGems function that implements it.
FULL_CONFIG_BY_FUNC = {}
for _entry in _FULL_CONFIG:
    if _entry and len(_entry) >= 2:
        _impl = _entry[1]
        # Fall back to str() for callables without a __name__ attribute.
        _key = getattr(_impl, "__name__", str(_impl))
        FULL_CONFIG_BY_FUNC.setdefault(_key, []).append(_entry)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Which ops to skip. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar
    excluded = resolve_user_setting(unused, "exclude")
    # Deduplicate the C++-patched op list before handing it to the registrar.
    patched = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=excluded,
        cpp_patched_ops=patched,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the named FlagGems ops and skip everything else.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
          only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
          only_enable(include="default")  # or include=None
        - Use a custom YAML:
          only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty or none of the names match
          known ops, the function warns and returns without registering.
    """
    wanted = resolve_user_setting(include, "include")
    # Nothing resolved -> warn and register nothing at all.
    if not wanted:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return

    global current_work_registrar
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=wanted,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that routes aten ops to FlagGems for the `with` body.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is not None, use_gems will not process 'exclude'.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # Private Library instance so registrations can be torn down on exit
        # without touching the module-global `aten_lib`.
        self.lib = torch.library.Library("aten", "IMPL")
        # Non-list-like values (e.g. None) are normalized to an empty list.
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # Fix: compare versions with PEP 440 semantics via torch_ge() instead
        # of raw string comparison. The previous check
        # `torch.__version__ >= "2.5"` is lexicographic, so e.g. "2.10" < "2.5"
        # and _destroy() would silently be skipped on newer torch releases.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        del current_work_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        # Imported lazily to avoid paying the cost unless actually accessed.
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return the result of get_all_ops() on the active registrar.

    Requires a prior enable()/only_enable() call; otherwise
    `current_work_registrar` is still None and this raises AttributeError.
    """
    return current_work_registrar.get_all_ops()
def all_registered_keys():
    """Return the result of get_all_keys() on the active registrar.

    Requires a prior enable()/only_enable() call; otherwise
    `current_work_registrar` is still None and this raises AttributeError.
    """
    return current_work_registrar.get_all_keys()
# Explicit public API of the flag_gems package.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]