Coverage for src/flag_gems/ops/__init__.py: 100%
222 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1from flag_gems.ops._functional_sym_constrain_range_for_size import (
2 _functional_sym_constrain_range_for_size,
3)
4from flag_gems.ops._safe_softmax import _safe_softmax
5from flag_gems.ops._upsample_nearest_exact1d import _upsample_nearest_exact1d
6from flag_gems.ops.abs import abs, abs_
7from flag_gems.ops.absolute import absolute
8from flag_gems.ops.acos import acos
9from flag_gems.ops.add import add, add_
10from flag_gems.ops.addcdiv import addcdiv
11from flag_gems.ops.addcmul import addcmul
12from flag_gems.ops.addmm import addmm, addmm_out
13from flag_gems.ops.addmv import addmv, addmv_out
14from flag_gems.ops.addr import addr
15from flag_gems.ops.alias_copy import alias_copy, alias_copy_out
16from flag_gems.ops.all import all, all_dim, all_dims
17from flag_gems.ops.amax import amax
18from flag_gems.ops.angle import angle
19from flag_gems.ops.any import any, any_dim, any_dims
20from flag_gems.ops.arange import arange, arange_start
21from flag_gems.ops.arcsinh import arcsinh, arcsinh_out
22from flag_gems.ops.arcsinh_ import arcsinh_
23from flag_gems.ops.arctanh_ import arctanh_
24from flag_gems.ops.argmax import argmax
25from flag_gems.ops.argmin import argmin
26from flag_gems.ops.asinh_ import asinh_
27from flag_gems.ops.atan import atan, atan_
28from flag_gems.ops.attention import (
29 ScaleDotProductAttention,
30 flash_attention_forward,
31 flash_attn_varlen_func,
32 scaled_dot_product_attention,
33 scaled_dot_product_attention_backward,
34 scaled_dot_product_attention_forward,
35)
36from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
37from flag_gems.ops.baddbmm import baddbmm
38from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
39from flag_gems.ops.bitwise_and import (
40 bitwise_and_scalar,
41 bitwise_and_scalar_,
42 bitwise_and_scalar_tensor,
43 bitwise_and_tensor,
44 bitwise_and_tensor_,
45)
46from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
47from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
48from flag_gems.ops.bitwise_or import (
49 bitwise_or_scalar,
50 bitwise_or_scalar_,
51 bitwise_or_scalar_tensor,
52 bitwise_or_tensor,
53 bitwise_or_tensor_,
54)
55from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
56from flag_gems.ops.bmm import bmm, bmm_out
57from flag_gems.ops.cat import cat
58from flag_gems.ops.ceil import ceil, ceil_, ceil_out
59from flag_gems.ops.celu import celu, celu_
60from flag_gems.ops.clamp import (
61 clamp,
62 clamp_,
63 clamp_min,
64 clamp_min_,
65 clamp_tensor,
66 clamp_tensor_,
67)
68from flag_gems.ops.contiguous import contiguous
69from flag_gems.ops.conv1d import conv1d
70from flag_gems.ops.conv2d import conv2d
71from flag_gems.ops.conv3d import conv3d
72from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
73from flag_gems.ops.copy import copy, copy_
74from flag_gems.ops.cos import cos, cos_
75from flag_gems.ops.count_nonzero import count_nonzero
76from flag_gems.ops.cummax import cummax
77from flag_gems.ops.cummin import cummin
78from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
79from flag_gems.ops.diag import diag
80from flag_gems.ops.diag_embed import diag_embed
81from flag_gems.ops.diagonal import diagonal_backward
82from flag_gems.ops.digamma_ import digamma_
83from flag_gems.ops.div import (
84 div_mode,
85 div_mode_,
86 floor_divide,
87 floor_divide_,
88 remainder,
89 remainder_,
90 true_divide,
91 true_divide_,
92 true_divide_out,
93)
94from flag_gems.ops.dot import dot
95from flag_gems.ops.dropout import dropout, dropout_backward
96from flag_gems.ops.elu import elu, elu_, elu_backward
97from flag_gems.ops.embedding import embedding, embedding_backward
98from flag_gems.ops.embedding_dense_backward import embedding_dense_backward
99from flag_gems.ops.eq import eq, eq_scalar, equal
100from flag_gems.ops.erf import erf, erf_
101from flag_gems.ops.exp import exp, exp_, exp_out
102from flag_gems.ops.exp2 import exp2, exp2_
103from flag_gems.ops.exponential_ import exponential_
104from flag_gems.ops.eye import eye
105from flag_gems.ops.eye_m import eye_m
106from flag_gems.ops.fill import (
107 fill_scalar,
108 fill_scalar_,
109 fill_scalar_out,
110 fill_tensor,
111 fill_tensor_,
112 fill_tensor_out,
113)
114from flag_gems.ops.flip import flip
115from flag_gems.ops.floor_ import floor_
116from flag_gems.ops.fmin import fmin, fmin_out
117from flag_gems.ops.full import full
118from flag_gems.ops.full_like import full_like
119from flag_gems.ops.gather import gather, gather_backward
120from flag_gems.ops.ge import ge, ge_scalar
121from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
122from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
123from flag_gems.ops.glu import glu, glu_backward
124from flag_gems.ops.groupnorm import group_norm, group_norm_backward
125from flag_gems.ops.gt import gt, gt_scalar
126from flag_gems.ops.hardsigmoid import hardsigmoid, hardsigmoid_out
127from flag_gems.ops.hardswish_ import hardswish_
128from flag_gems.ops.hstack import hstack
129from flag_gems.ops.hypot import hypot, hypot_out
130from flag_gems.ops.i0 import i0, i0_out
131from flag_gems.ops.i0_ import i0_
132from flag_gems.ops.index import index
133from flag_gems.ops.index_add import index_add, index_add_
134from flag_gems.ops.index_put import index_put, index_put_
135from flag_gems.ops.index_select import index_select
136from flag_gems.ops.isclose import allclose, isclose
137from flag_gems.ops.isfinite import isfinite
138from flag_gems.ops.isin import isin
139from flag_gems.ops.isinf import isinf
140from flag_gems.ops.isnan import isnan
141from flag_gems.ops.kron import kron
142from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
143from flag_gems.ops.le import le, le_scalar
144from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
145from flag_gems.ops.lift_fresh_copy import lift_fresh_copy, lift_fresh_copy_out
146from flag_gems.ops.linspace import linspace
147from flag_gems.ops.log import log
148from flag_gems.ops.log1p_ import log1p_
149from flag_gems.ops.log_sigmoid import log_sigmoid
150from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
151from flag_gems.ops.logaddexp import logaddexp, logaddexp_out
152from flag_gems.ops.logical_and import logical_and, logical_and_
153from flag_gems.ops.logical_not import logical_not
154from flag_gems.ops.logical_or import logical_or, logical_or_
155from flag_gems.ops.logical_xor import logical_xor
156from flag_gems.ops.logit import logit, logit_out
157from flag_gems.ops.logit_ import logit_
158from flag_gems.ops.logspace import logspace
159from flag_gems.ops.lt import lt, lt_scalar
160from flag_gems.ops.margin_ranking_loss import margin_ranking_loss
161from flag_gems.ops.masked_fill import masked_fill, masked_fill_
162from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
163from flag_gems.ops.masked_select import masked_select
164from flag_gems.ops.max import max, max_dim
165from flag_gems.ops.max_pool2d_with_indices import (
166 max_pool2d_backward,
167 max_pool2d_with_indices,
168)
169from flag_gems.ops.maximum import maximum
170from flag_gems.ops.mean import mean, mean_dim
171from flag_gems.ops.min import min, min_dim
172from flag_gems.ops.minimum import minimum
173from flag_gems.ops.mm import mm, mm_out
174from flag_gems.ops.mse_loss import mse_loss
175from flag_gems.ops.mul import mul, mul_
176from flag_gems.ops.multinomial import multinomial
177from flag_gems.ops.mv import mv
178from flag_gems.ops.nan_to_num import nan_to_num
179from flag_gems.ops.ne import ne, ne_scalar
180from flag_gems.ops.neg import neg, neg_
181from flag_gems.ops.nll_loss_nd import nll_loss_nd_backward, nll_loss_nd_forward
182from flag_gems.ops.nllloss import (
183 nll_loss2d_backward,
184 nll_loss2d_forward,
185 nll_loss_backward,
186 nll_loss_forward,
187)
188from flag_gems.ops.nonzero import nonzero
189from flag_gems.ops.normal import (
190 normal_,
191 normal_float_tensor,
192 normal_tensor_float,
193 normal_tensor_tensor,
194)
195from flag_gems.ops.one_hot import one_hot
196from flag_gems.ops.ones import ones
197from flag_gems.ops.ones_like import ones_like
198from flag_gems.ops.pad import constant_pad_nd, pad
199from flag_gems.ops.per_token_group_quant_fp8 import (
200 SUPPORTED_FP8_DTYPE,
201 per_token_group_quant_fp8,
202)
203from flag_gems.ops.pixel_unshuffle import pixel_unshuffle, pixel_unshuffle_out
204from flag_gems.ops.polar import polar
205from flag_gems.ops.pow import (
206 pow_scalar,
207 pow_tensor_scalar,
208 pow_tensor_scalar_,
209 pow_tensor_tensor,
210 pow_tensor_tensor_,
211)
212from flag_gems.ops.prelu import prelu
213from flag_gems.ops.prod import prod, prod_dim
214from flag_gems.ops.quantile import quantile
215from flag_gems.ops.rand import rand
216from flag_gems.ops.rand_like import rand_like
217from flag_gems.ops.randn import randn
218from flag_gems.ops.randn_like import randn_like
219from flag_gems.ops.randperm import randperm
220from flag_gems.ops.reciprocal import reciprocal, reciprocal_
221from flag_gems.ops.reflection_pad1d import reflection_pad1d, reflection_pad1d_out
222from flag_gems.ops.reflection_pad2d import reflection_pad2d, reflection_pad2d_out
223from flag_gems.ops.relu import relu, relu_
224from flag_gems.ops.relu6 import relu6
225from flag_gems.ops.repeat import repeat
226from flag_gems.ops.repeat_interleave import (
227 repeat_interleave_self_int,
228 repeat_interleave_self_tensor,
229 repeat_interleave_tensor,
230)
231from flag_gems.ops.replication_pad1d import replication_pad1d, replication_pad1d_out
232from flag_gems.ops.replication_pad3d import replication_pad3d
233from flag_gems.ops.resolve_conj import resolve_conj
234from flag_gems.ops.resolve_neg import resolve_neg
235from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
236from flag_gems.ops.rrelu_with_noise_backward import rrelu_with_noise_backward
237from flag_gems.ops.rsqrt import rsqrt, rsqrt_
238from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
239from flag_gems.ops.scatter import scatter, scatter_
240from flag_gems.ops.scatter_add_ import scatter_add_
241from flag_gems.ops.select_backward import select_backward
242from flag_gems.ops.select_scatter import select_scatter
243from flag_gems.ops.selu import selu
244from flag_gems.ops.selu_ import selu_
245from flag_gems.ops.sgn_ import sgn_
246from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
247from flag_gems.ops.silu import silu, silu_, silu_backward
248from flag_gems.ops.sin import sin, sin_
249from flag_gems.ops.sinh_ import sinh_
250from flag_gems.ops.slice_backward import slice_backward
251from flag_gems.ops.slice_scatter import slice_scatter
252from flag_gems.ops.soft_margin_loss import soft_margin_loss, soft_margin_loss_out
253from flag_gems.ops.softmax import softmax, softmax_backward
254from flag_gems.ops.softplus import softplus
255from flag_gems.ops.softshrink import softshrink, softshrink_out
256from flag_gems.ops.sort import sort, sort_stable
257from flag_gems.ops.special_i0e import special_i0e, special_i0e_out
258from flag_gems.ops.special_i1 import special_i1, special_i1_out
259from flag_gems.ops.sqrt import sqrt, sqrt_
260from flag_gems.ops.stack import stack
261from flag_gems.ops.std import std
262from flag_gems.ops.sub import sub, sub_
263from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
264from flag_gems.ops.t_copy import t_copy, t_copy_out
265from flag_gems.ops.tan import tan, tan_
266from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
267from flag_gems.ops.threshold import threshold, threshold_backward
268from flag_gems.ops.tile import tile
269from flag_gems.ops.to import to_copy
270from flag_gems.ops.topk import topk
271from flag_gems.ops.trace import trace
272from flag_gems.ops.tril import tril, tril_out
273from flag_gems.ops.triu import triu, triu_
274from flag_gems.ops.unfold_backward import unfold_backward
275from flag_gems.ops.uniform import uniform_
276from flag_gems.ops.unique import _unique2
277from flag_gems.ops.upsample_bicubic2d import upsample_bicubic2d
278from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
279from flag_gems.ops.upsample_linear1d import upsample_linear1d
280from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
281from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
282from flag_gems.ops.upsample_nearest3d import upsample_nearest3d
283from flag_gems.ops.var_mean import var_mean
284from flag_gems.ops.vdot import vdot
285from flag_gems.ops.vector_norm import vector_norm
286from flag_gems.ops.vstack import vstack
287from flag_gems.ops.w8a8_block_fp8_matmul import w8a8_block_fp8_matmul
288from flag_gems.ops.weightnorm import (
289 weight_norm_interface,
290 weight_norm_interface_backward,
291)
292from flag_gems.ops.where import (
293 where_scalar_other,
294 where_scalar_self,
295 where_self,
296 where_self_out,
297)
298from flag_gems.ops.zero import zero, zero_out
299from flag_gems.ops.zeros import zero_, zeros
300from flag_gems.ops.zeros_like import zeros_like
# Public API of ``flag_gems.ops``: every op imported above is re-exported
# here so ``from flag_gems.ops import *`` exposes exactly this set.
# Kept in plain code-point order (exactly what ``sorted()`` yields): the
# uppercase names come first, then ``_``-prefixed names, then lowercase
# ones. When adding an op, add its import above and insert the name here
# in sorted position so the two lists stay easy to diff against each other.
__all__ = [
    "SUPPORTED_FP8_DTYPE",
    "ScaleDotProductAttention",
    "_conv_depthwise2d",
    "_functional_sym_constrain_range_for_size",
    "_safe_softmax",
    "_unique2",
    "_upsample_bicubic2d_aa",
    "_upsample_nearest_exact1d",
    "abs",
    "abs_",
    "absolute",
    "acos",
    "add",
    "add_",
    "addcdiv",
    "addcmul",
    "addmm",
    "addmm_out",
    "addmv",
    "addmv_out",
    "addr",
    "alias_copy",
    "alias_copy_out",
    "all",
    "all_dim",
    "all_dims",
    "allclose",
    "amax",
    "angle",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "arcsinh",
    "arcsinh_",
    "arcsinh_out",
    "arctanh_",
    "argmax",
    "argmin",
    "asinh_",
    "atan",
    "atan_",
    "avg_pool2d",
    "avg_pool2d_backward",
    "baddbmm",
    "batch_norm",
    "batch_norm_backward",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_not",
    "bitwise_not_",
    "bitwise_or_scalar",
    "bitwise_or_scalar_",
    "bitwise_or_scalar_tensor",
    "bitwise_or_tensor",
    "bitwise_or_tensor_",
    "bitwise_right_shift",
    "bmm",
    "bmm_out",
    "cat",
    "ceil",
    "ceil_",
    "ceil_out",
    "celu",
    "celu_",
    "clamp",
    "clamp_",
    "clamp_min",
    "clamp_min_",
    "clamp_tensor",
    "clamp_tensor_",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "copy",
    "copy_",
    "cos",
    "cos_",
    "count_nonzero",
    "cummax",
    "cummin",
    "cumsum",
    "cumsum_out",
    "diag",
    "diag_embed",
    "diagonal_backward",
    "digamma_",
    "div_mode",
    "div_mode_",
    "dot",
    "dropout",
    "dropout_backward",
    "elu",
    "elu_",
    "elu_backward",
    "embedding",
    "embedding_backward",
    "embedding_dense_backward",
    "eq",
    "eq_scalar",
    "equal",
    "erf",
    "erf_",
    "exp",
    "exp2",
    "exp2_",
    "exp_",
    "exp_out",
    "exponential_",
    "eye",
    "eye_m",
    "fill_scalar",
    "fill_scalar_",
    "fill_scalar_out",
    "fill_tensor",
    "fill_tensor_",
    "fill_tensor_out",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "flip",
    "floor_",
    "floor_divide",
    "floor_divide_",
    "fmin",
    "fmin_out",
    "full",
    "full_like",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "get_scheduler_metadata",
    "glu",
    "glu_backward",
    "group_norm",
    "group_norm_backward",
    "gt",
    "gt_scalar",
    "hardsigmoid",
    "hardsigmoid_out",
    "hardswish_",
    "hstack",
    "hypot",
    "hypot_out",
    "i0",
    "i0_",
    "i0_out",
    "index",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isclose",
    "isfinite",
    "isin",
    "isinf",
    "isnan",
    "kron",
    "layer_norm",
    "layer_norm_backward",
    "le",
    "le_scalar",
    "lerp_scalar",
    "lerp_scalar_",
    "lerp_tensor",
    "lerp_tensor_",
    "lift_fresh_copy",
    "lift_fresh_copy_out",
    "linspace",
    "log",
    "log1p_",
    "log_sigmoid",
    "log_softmax",
    "log_softmax_backward",
    "logaddexp",
    "logaddexp_out",
    "logical_and",
    "logical_and_",
    "logical_not",
    "logical_or",
    "logical_or_",
    "logical_xor",
    "logit",
    "logit_",
    "logit_out",
    "logspace",
    "lt",
    "lt_scalar",
    "margin_ranking_loss",
    "masked_fill",
    "masked_fill_",
    "masked_scatter",
    "masked_scatter_",
    "masked_select",
    "max",
    "max_dim",
    "max_pool2d_backward",
    "max_pool2d_with_indices",
    "maximum",
    "mean",
    "mean_dim",
    "min",
    "min_dim",
    "minimum",
    "mm",
    "mm_out",
    "mse_loss",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "nan_to_num",
    "ne",
    "ne_scalar",
    "neg",
    "neg_",
    "nll_loss2d_backward",
    "nll_loss2d_forward",
    "nll_loss_backward",
    "nll_loss_forward",
    "nll_loss_nd_backward",
    "nll_loss_nd_forward",
    "nonzero",
    "normal_",
    "normal_float_tensor",
    "normal_tensor_float",
    "normal_tensor_tensor",
    "normed_cumsum",
    "one_hot",
    "ones",
    "ones_like",
    "pad",
    "per_token_group_quant_fp8",
    "pixel_unshuffle",
    "pixel_unshuffle_out",
    "polar",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prelu",
    "prod",
    "prod_dim",
    "quantile",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "reciprocal",
    "reciprocal_",
    "reflection_pad1d",
    "reflection_pad1d_out",
    "reflection_pad2d",
    "reflection_pad2d_out",
    "relu",
    "relu6",
    "relu_",
    "remainder",
    "remainder_",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "replication_pad1d",
    "replication_pad1d_out",
    "replication_pad3d",
    "resolve_conj",
    "resolve_neg",
    "rms_norm",
    "rms_norm_backward",
    "rms_norm_forward",
    "rrelu_with_noise_backward",
    "rsqrt",
    "rsqrt_",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scaled_softmax_backward",
    "scaled_softmax_forward",
    "scatter",
    "scatter_",
    "scatter_add_",
    "select_backward",
    "select_scatter",
    "selu",
    "selu_",
    "sgn_",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "silu",
    "silu_",
    "silu_backward",
    "sin",
    "sin_",
    "sinh_",
    "slice_backward",
    "slice_scatter",
    "soft_margin_loss",
    "soft_margin_loss_out",
    "softmax",
    "softmax_backward",
    "softplus",
    "softshrink",
    "softshrink_out",
    "sort",
    "sort_stable",
    "special_i0e",
    "special_i0e_out",
    "special_i1",
    "special_i1_out",
    "sqrt",
    "sqrt_",
    "stack",
    "std",
    "sub",
    "sub_",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "t_copy",
    "t_copy_out",
    "tan",
    "tan_",
    "tanh",
    "tanh_",
    "tanh_backward",
    "threshold",
    "threshold_backward",
    "tile",
    "to_copy",
    "topk",
    "trace",
    "tril",
    "tril_out",
    "triu",
    "triu_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "unfold_backward",
    "uniform_",
    "upsample_bicubic2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "var_mean",
    "vdot",
    "vector_norm",
    "vstack",
    "w8a8_block_fp8_matmul",
    "weight_norm_interface",
    "weight_norm_interface_backward",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero",
    "zero_",
    "zero_out",
    "zeros",
    "zeros_like",
]