Coverage for src/flag_gems/__init__.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
# Public package version string.
__version__ = "4.2.1.rc.0"
# Device name and vendor identity resolved by the active runtime backend.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Default torch Library handle (IMPL mode) used by enable()/only_enable().
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class, exposed as a module alias so callers may override it.
registrar = Register
# Holds the Register instance created by the most recent enable()/only_enable().
current_work_registrar = None
# Let the vendor backend swap in its customized op implementations.
# NOTE(review): this mutates the module's globals — presumably it replaces some
# of the star-imported ops above; confirm against runtime.replace_customized_ops.
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Return True if the installed torch version is at least *v* (a version string)."""
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Master registration table. Each entry is
#   (aten_schema_name, flag_gems_implementation[, predicate])
# where the optional zero-argument predicate gates registration (used to skip
# ops on older torch versions). Consumed by `Register` via enable()/only_enable().
_FULL_CONFIG = (
    ("_flash_attention_forward", flash_attention_forward),
    ("_log_softmax", log_softmax),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_softmax", softmax),
    ("_softmax_backward_data", softmax_backward),
    # Only registered on torch >= 2.4 (uses the shared torch_ge helper
    # instead of an inline version.parse comparison).
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcmul", addcmul),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addr", addr),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("argmax", argmax),
    ("argmin", argmin),
    ("atan", atan),
    ("atan_", atan_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("baddbmm", baddbmm),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    # Only registered on torch >= 2.4.
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("cos", cos),
    ("cos_", cos_),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Tensor", fill_tensor),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hstack", hstack),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log_sigmoid", log_sigmoid),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("maximum", maximum),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool2d_backward", max_pool2d_backward),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal_", normal_),
    ("ones", ones),
    ("ones_like", ones_like),
    ("one_hot", one_hot),
    ("pad", pad),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("relu", relu),
    ("relu_", relu_),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_scatter", select_scatter),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("slice_scatter", slice_scatter),
    ("softplus", softplus),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.dim_IntList", sum_dim),
    ("sum.IntList_out", sum_dim_out),
    ("sum.out", sum_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("uniform_", uniform_),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.ScalarOther", where_scalar_other),
    ("where.ScalarSelf", where_scalar_self),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zeros", zeros),
    ("zero_", zero_),
    ("zeros_like", zeros_like),
)
# Index of _FULL_CONFIG entries keyed by implementation-function name, so
# include/exclude lookups by name do not need to scan the whole table.
FULL_CONFIG_BY_FUNC = {}
for _entry in _FULL_CONFIG:
    # Skip malformed rows that do not carry the (aten_name, fn) pair.
    if _entry and len(_entry) >= 2:
        _impl = _entry[1]
        _key = _impl.__name__ if hasattr(_impl, "__name__") else str(_impl)
        FULL_CONFIG_BY_FUNC.setdefault(_key, []).append(_entry)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except those explicitly excluded.

    Args:
        lib: torch.library.Library to register into; defaults to the
            module-level `aten_lib` (IMPL mode).
        unused: Ops to skip. Accepted forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a .yml/.yaml file containing an `exclude:` list.
            - "default" or None: auto-load the vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Enable FlagGems logging when True.
        once: Log only once when True.
        path: Optional log output path used while recording.

    Notes:
        - An empty exclude list/YAML means all ops are registered.
    """
    global current_work_registrar
    excluded = resolve_user_setting(unused, "exclude")
    # Deduplicate the C++-patched op list before handing it to the registrar.
    patched_cpp_ops = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=excluded,
        cpp_patched_ops=patched_cpp_ops,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the named FlagGems ops and skip everything else.

    Args:
        lib: torch.library.Library to register into; defaults to the
            module-level `aten_lib` (IMPL mode).
        include: Ops to register. Accepted forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a .yml/.yaml file (a plain list or an `include:` key).
            - "default" or None: auto-load the vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Enable FlagGems logging when True.
        once: Log only once when True.
        path: Optional log output path used while recording.

    Classic usage:
        only_enable(include=["rms_norm", "softmax"])
        only_enable(include="default")  # or include=None, vendor default YAML
        only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If nothing resolves from the include list/YAML, a warning is
          emitted and no ops are registered.
    """
    selected = resolve_user_setting(include, "include")
    # Guard clause: registering nothing is almost certainly a caller mistake.
    if not selected:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return
    global current_work_registrar
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=selected,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops on entry and tears the
    registrations down on exit.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is not None, use_gems will not process 'exclude'.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # A private Library instance so registrations can be destroyed on exit
        # without touching the module-level `aten_lib`.
        self.lib = torch.library.Library("aten", "IMPL")
        # Unsupported argument types silently fall back to "no filter".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        # 'include' wins: when present, 'exclude' is ignored entirely.
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        # FIX: previously this compared `torch.__version__ >= "2.5"` directly.
        # That only works because modern torch's TorchVersion overrides the
        # comparison operators; with a plain string it mis-orders versions
        # (e.g. "2.10" < "2.5"). Use the module's version-aware helper instead.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Drop the module-level registrar created by __enter__.
        del current_work_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return the ops registered by the most recent enable()/only_enable() call."""
    reg = current_work_registrar
    return reg.get_all_ops()
def all_registered_keys():
    """Return the registration keys from the most recent enable()/only_enable() call."""
    reg = current_work_registrar
    return reg.get_all_keys()
# Public API of the flag_gems package entry module.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]