Coverage for src/flag_gems/ops/__init__.py: 100%
220 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1from flag_gems.ops._functional_sym_constrain_range_for_size import (
2 _functional_sym_constrain_range_for_size,
3)
4from flag_gems.ops._safe_softmax import _safe_softmax
5from flag_gems.ops._upsample_nearest_exact1d import _upsample_nearest_exact1d
6from flag_gems.ops.abs import abs, abs_
7from flag_gems.ops.absolute import absolute
8from flag_gems.ops.acos import acos
9from flag_gems.ops.add import add, add_
10from flag_gems.ops.addcdiv import addcdiv
11from flag_gems.ops.addcmul import addcmul
12from flag_gems.ops.addmm import addmm, addmm_out
13from flag_gems.ops.addmv import addmv, addmv_out
14from flag_gems.ops.addr import addr
15from flag_gems.ops.alias_copy import alias_copy, alias_copy_out
16from flag_gems.ops.all import all, all_dim, all_dims
17from flag_gems.ops.amax import amax
18from flag_gems.ops.angle import angle
19from flag_gems.ops.any import any, any_dim, any_dims
20from flag_gems.ops.arange import arange, arange_start
21from flag_gems.ops.arcsinh import arcsinh, arcsinh_out
22from flag_gems.ops.arcsinh_ import arcsinh_
23from flag_gems.ops.arctanh_ import arctanh_
24from flag_gems.ops.argmax import argmax
25from flag_gems.ops.argmin import argmin
26from flag_gems.ops.asinh_ import asinh_
27from flag_gems.ops.atan import atan, atan_
28from flag_gems.ops.attention import (
29 ScaleDotProductAttention,
30 flash_attention_forward,
31 flash_attn_varlen_func,
32 scaled_dot_product_attention,
33 scaled_dot_product_attention_backward,
34 scaled_dot_product_attention_forward,
35)
36from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
37from flag_gems.ops.baddbmm import baddbmm
38from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
39from flag_gems.ops.bitwise_and import (
40 bitwise_and_scalar,
41 bitwise_and_scalar_,
42 bitwise_and_scalar_tensor,
43 bitwise_and_tensor,
44 bitwise_and_tensor_,
45)
46from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
47from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
48from flag_gems.ops.bitwise_or import (
49 bitwise_or_scalar,
50 bitwise_or_scalar_,
51 bitwise_or_scalar_tensor,
52 bitwise_or_tensor,
53 bitwise_or_tensor_,
54)
55from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
56from flag_gems.ops.bmm import bmm, bmm_out
57from flag_gems.ops.cat import cat
58from flag_gems.ops.ceil import ceil, ceil_, ceil_out
59from flag_gems.ops.celu import celu, celu_
60from flag_gems.ops.clamp import (
61 clamp,
62 clamp_,
63 clamp_min,
64 clamp_min_,
65 clamp_tensor,
66 clamp_tensor_,
67)
68from flag_gems.ops.contiguous import contiguous
69from flag_gems.ops.conv1d import conv1d
70from flag_gems.ops.conv2d import conv2d
71from flag_gems.ops.conv3d import conv3d
72from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
73from flag_gems.ops.copy import copy, copy_
74from flag_gems.ops.cos import cos, cos_
75from flag_gems.ops.count_nonzero import count_nonzero
76from flag_gems.ops.cummax import cummax
77from flag_gems.ops.cummin import cummin
78from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
79from flag_gems.ops.diag import diag
80from flag_gems.ops.diag_embed import diag_embed
81from flag_gems.ops.diagonal import diagonal_backward
82from flag_gems.ops.digamma_ import digamma_
83from flag_gems.ops.div import (
84 div_mode,
85 div_mode_,
86 floor_divide,
87 floor_divide_,
88 remainder,
89 remainder_,
90 true_divide,
91 true_divide_,
92 true_divide_out,
93)
94from flag_gems.ops.dot import dot
95from flag_gems.ops.dropout import dropout, dropout_backward
96from flag_gems.ops.elu import elu, elu_, elu_backward
97from flag_gems.ops.embedding import embedding, embedding_backward
98from flag_gems.ops.embedding_dense_backward import embedding_dense_backward
99from flag_gems.ops.eq import eq, eq_scalar, equal
100from flag_gems.ops.erf import erf, erf_
101from flag_gems.ops.exp import exp, exp_, exp_out
102from flag_gems.ops.exp2 import exp2, exp2_
103from flag_gems.ops.exponential_ import exponential_
104from flag_gems.ops.eye import eye
105from flag_gems.ops.eye_m import eye_m
106from flag_gems.ops.fill import (
107 fill_scalar,
108 fill_scalar_,
109 fill_scalar_out,
110 fill_tensor,
111 fill_tensor_,
112 fill_tensor_out,
113)
114from flag_gems.ops.flip import flip
115from flag_gems.ops.floor_ import floor_
116from flag_gems.ops.fmin import fmin, fmin_out
117from flag_gems.ops.full import full
118from flag_gems.ops.full_like import full_like
119from flag_gems.ops.gather import gather, gather_backward
120from flag_gems.ops.ge import ge, ge_scalar
121from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
122from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
123from flag_gems.ops.glu import glu, glu_backward
124from flag_gems.ops.groupnorm import group_norm, group_norm_backward
125from flag_gems.ops.gt import gt, gt_scalar
126from flag_gems.ops.hardsigmoid import hardsigmoid, hardsigmoid_out
127from flag_gems.ops.hardswish_ import hardswish_
128from flag_gems.ops.hstack import hstack
129from flag_gems.ops.hypot import hypot, hypot_out
130from flag_gems.ops.i0 import i0, i0_out
131from flag_gems.ops.i0_ import i0_
132from flag_gems.ops.index import index
133from flag_gems.ops.index_add import index_add, index_add_
134from flag_gems.ops.index_put import index_put, index_put_
135from flag_gems.ops.index_select import index_select
136from flag_gems.ops.isclose import allclose, isclose
137from flag_gems.ops.isfinite import isfinite
138from flag_gems.ops.isin import isin
139from flag_gems.ops.isinf import isinf
140from flag_gems.ops.isnan import isnan
141from flag_gems.ops.kron import kron
142from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
143from flag_gems.ops.le import le, le_scalar
144from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
145from flag_gems.ops.lift_fresh_copy import lift_fresh_copy, lift_fresh_copy_out
146from flag_gems.ops.linspace import linspace
147from flag_gems.ops.log import log
148from flag_gems.ops.log1p_ import log1p_
149from flag_gems.ops.log_sigmoid import log_sigmoid
150from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
151from flag_gems.ops.logaddexp import logaddexp, logaddexp_out
152from flag_gems.ops.logical_and import logical_and, logical_and_
153from flag_gems.ops.logical_not import logical_not
154from flag_gems.ops.logical_or import logical_or, logical_or_
155from flag_gems.ops.logical_xor import logical_xor
156from flag_gems.ops.logit import logit, logit_out
157from flag_gems.ops.logspace import logspace
158from flag_gems.ops.lt import lt, lt_scalar
159from flag_gems.ops.margin_ranking_loss import margin_ranking_loss
160from flag_gems.ops.masked_fill import masked_fill, masked_fill_
161from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
162from flag_gems.ops.masked_select import masked_select
163from flag_gems.ops.max import max, max_dim
164from flag_gems.ops.max_pool2d_with_indices import (
165 max_pool2d_backward,
166 max_pool2d_with_indices,
167)
168from flag_gems.ops.maximum import maximum
169from flag_gems.ops.mean import mean, mean_dim
170from flag_gems.ops.min import min, min_dim
171from flag_gems.ops.minimum import minimum
172from flag_gems.ops.mm import mm, mm_out
173from flag_gems.ops.mse_loss import mse_loss
174from flag_gems.ops.mul import mul, mul_
175from flag_gems.ops.multinomial import multinomial
176from flag_gems.ops.mv import mv
177from flag_gems.ops.nan_to_num import nan_to_num
178from flag_gems.ops.ne import ne, ne_scalar
179from flag_gems.ops.neg import neg, neg_
180from flag_gems.ops.nll_loss_nd import nll_loss_nd_backward, nll_loss_nd_forward
181from flag_gems.ops.nllloss import (
182 nll_loss2d_backward,
183 nll_loss2d_forward,
184 nll_loss_backward,
185 nll_loss_forward,
186)
187from flag_gems.ops.nonzero import nonzero
188from flag_gems.ops.normal import (
189 normal_,
190 normal_float_tensor,
191 normal_tensor_float,
192 normal_tensor_tensor,
193)
194from flag_gems.ops.one_hot import one_hot
195from flag_gems.ops.ones import ones
196from flag_gems.ops.ones_like import ones_like
197from flag_gems.ops.pad import constant_pad_nd, pad
198from flag_gems.ops.per_token_group_quant_fp8 import (
199 SUPPORTED_FP8_DTYPE,
200 per_token_group_quant_fp8,
201)
202from flag_gems.ops.pixel_unshuffle import pixel_unshuffle, pixel_unshuffle_out
203from flag_gems.ops.polar import polar
204from flag_gems.ops.pow import (
205 pow_scalar,
206 pow_tensor_scalar,
207 pow_tensor_scalar_,
208 pow_tensor_tensor,
209 pow_tensor_tensor_,
210)
211from flag_gems.ops.prelu import prelu
212from flag_gems.ops.prod import prod, prod_dim
213from flag_gems.ops.quantile import quantile
214from flag_gems.ops.rand import rand
215from flag_gems.ops.rand_like import rand_like
216from flag_gems.ops.randn import randn
217from flag_gems.ops.randn_like import randn_like
218from flag_gems.ops.randperm import randperm
219from flag_gems.ops.reciprocal import reciprocal, reciprocal_
220from flag_gems.ops.reflection_pad1d import reflection_pad1d, reflection_pad1d_out
221from flag_gems.ops.reflection_pad2d import reflection_pad2d, reflection_pad2d_out
222from flag_gems.ops.relu import relu, relu_
223from flag_gems.ops.relu6 import relu6
224from flag_gems.ops.repeat import repeat
225from flag_gems.ops.repeat_interleave import (
226 repeat_interleave_self_int,
227 repeat_interleave_self_tensor,
228 repeat_interleave_tensor,
229)
230from flag_gems.ops.replication_pad1d import replication_pad1d, replication_pad1d_out
231from flag_gems.ops.replication_pad3d import replication_pad3d
232from flag_gems.ops.resolve_conj import resolve_conj
233from flag_gems.ops.resolve_neg import resolve_neg
234from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
235from flag_gems.ops.rrelu_with_noise_backward import rrelu_with_noise_backward
236from flag_gems.ops.rsqrt import rsqrt, rsqrt_
237from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
238from flag_gems.ops.scatter import scatter, scatter_
239from flag_gems.ops.scatter_add_ import scatter_add_
240from flag_gems.ops.select_scatter import select_scatter
241from flag_gems.ops.selu import selu
242from flag_gems.ops.selu_ import selu_
243from flag_gems.ops.sgn_ import sgn_
244from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
245from flag_gems.ops.silu import silu, silu_, silu_backward
246from flag_gems.ops.sin import sin, sin_
247from flag_gems.ops.sinh_ import sinh_
248from flag_gems.ops.slice_backward import slice_backward
249from flag_gems.ops.slice_scatter import slice_scatter
250from flag_gems.ops.soft_margin_loss import soft_margin_loss, soft_margin_loss_out
251from flag_gems.ops.softmax import softmax, softmax_backward
252from flag_gems.ops.softplus import softplus
253from flag_gems.ops.softshrink import softshrink, softshrink_out
254from flag_gems.ops.sort import sort, sort_stable
255from flag_gems.ops.special_i0e import special_i0e, special_i0e_out
256from flag_gems.ops.special_i1 import special_i1, special_i1_out
257from flag_gems.ops.sqrt import sqrt, sqrt_
258from flag_gems.ops.stack import stack
259from flag_gems.ops.std import std
260from flag_gems.ops.sub import sub, sub_
261from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
262from flag_gems.ops.t_copy import t_copy, t_copy_out
263from flag_gems.ops.tan import tan, tan_
264from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
265from flag_gems.ops.threshold import threshold, threshold_backward
266from flag_gems.ops.tile import tile
267from flag_gems.ops.to import to_copy
268from flag_gems.ops.topk import topk
269from flag_gems.ops.trace import trace
270from flag_gems.ops.tril import tril, tril_out
271from flag_gems.ops.triu import triu, triu_
272from flag_gems.ops.unfold_backward import unfold_backward
273from flag_gems.ops.uniform import uniform_
274from flag_gems.ops.unique import _unique2
275from flag_gems.ops.upsample_bicubic2d import upsample_bicubic2d
276from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
277from flag_gems.ops.upsample_linear1d import upsample_linear1d
278from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
279from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
280from flag_gems.ops.upsample_nearest3d import upsample_nearest3d
281from flag_gems.ops.var_mean import var_mean
282from flag_gems.ops.vdot import vdot
283from flag_gems.ops.vector_norm import vector_norm
284from flag_gems.ops.vstack import vstack
285from flag_gems.ops.w8a8_block_fp8_matmul import w8a8_block_fp8_matmul
286from flag_gems.ops.weightnorm import (
287 weight_norm_interface,
288 weight_norm_interface_backward,
289)
290from flag_gems.ops.where import (
291 where_scalar_other,
292 where_scalar_self,
293 where_self,
294 where_self_out,
295)
296from flag_gems.ops.zero import zero, zero_out
297from flag_gems.ops.zeros import zero_, zeros
298from flag_gems.ops.zeros_like import zeros_like
# Public API of ``flag_gems.ops``: every operator imported above is re-exported
# here, so ``from flag_gems.ops import *`` exposes exactly these names.
#
# Keep this list in plain ``sorted()`` (code-point) order:
# - upper-case names come first, then ``_``-prefixed names, then the rest;
# - digits sort before ``_``, so ``exp2`` precedes ``exp_`` and ``relu6`` precedes ``relu_``.
# This matches how names are ordered inside the ``from ... import (...)`` statements
# above, and it makes a missing or duplicated entry easy to spot in review.
# Invariant: ``__all__ == sorted(set(__all__))``.
__all__ = [
    "SUPPORTED_FP8_DTYPE",
    "ScaleDotProductAttention",
    "_conv_depthwise2d",
    "_functional_sym_constrain_range_for_size",
    "_safe_softmax",
    "_unique2",
    "_upsample_bicubic2d_aa",
    "_upsample_nearest_exact1d",
    "abs",
    "abs_",
    "absolute",
    "acos",
    "add",
    "add_",
    "addcdiv",
    "addcmul",
    "addmm",
    "addmm_out",
    "addmv",
    "addmv_out",
    "addr",
    "alias_copy",
    "alias_copy_out",
    "all",
    "all_dim",
    "all_dims",
    "allclose",
    "amax",
    "angle",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "arcsinh",
    "arcsinh_",
    "arcsinh_out",
    "arctanh_",
    "argmax",
    "argmin",
    "asinh_",
    "atan",
    "atan_",
    "avg_pool2d",
    "avg_pool2d_backward",
    "baddbmm",
    "batch_norm",
    "batch_norm_backward",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_not",
    "bitwise_not_",
    "bitwise_or_scalar",
    "bitwise_or_scalar_",
    "bitwise_or_scalar_tensor",
    "bitwise_or_tensor",
    "bitwise_or_tensor_",
    "bitwise_right_shift",
    "bmm",
    "bmm_out",
    "cat",
    "ceil",
    "ceil_",
    "ceil_out",
    "celu",
    "celu_",
    "clamp",
    "clamp_",
    "clamp_min",
    "clamp_min_",
    "clamp_tensor",
    "clamp_tensor_",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "copy",
    "copy_",
    "cos",
    "cos_",
    "count_nonzero",
    "cummax",
    "cummin",
    "cumsum",
    "cumsum_out",
    "diag",
    "diag_embed",
    "diagonal_backward",
    "digamma_",
    "div_mode",
    "div_mode_",
    "dot",
    "dropout",
    "dropout_backward",
    "elu",
    "elu_",
    "elu_backward",
    "embedding",
    "embedding_backward",
    "embedding_dense_backward",
    "eq",
    "eq_scalar",
    "equal",
    "erf",
    "erf_",
    "exp",
    "exp2",
    "exp2_",
    "exp_",
    "exp_out",
    "exponential_",
    "eye",
    "eye_m",
    "fill_scalar",
    "fill_scalar_",
    "fill_scalar_out",
    "fill_tensor",
    "fill_tensor_",
    "fill_tensor_out",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "flip",
    "floor_",
    "floor_divide",
    "floor_divide_",
    "fmin",
    "fmin_out",
    "full",
    "full_like",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "get_scheduler_metadata",
    "glu",
    "glu_backward",
    "group_norm",
    "group_norm_backward",
    "gt",
    "gt_scalar",
    "hardsigmoid",
    "hardsigmoid_out",
    "hardswish_",
    "hstack",
    "hypot",
    "hypot_out",
    "i0",
    "i0_",
    "i0_out",
    "index",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isclose",
    "isfinite",
    "isin",
    "isinf",
    "isnan",
    "kron",
    "layer_norm",
    "layer_norm_backward",
    "le",
    "le_scalar",
    "lerp_scalar",
    "lerp_scalar_",
    "lerp_tensor",
    "lerp_tensor_",
    "lift_fresh_copy",
    "lift_fresh_copy_out",
    "linspace",
    "log",
    "log1p_",
    "log_sigmoid",
    "log_softmax",
    "log_softmax_backward",
    "logaddexp",
    "logaddexp_out",
    "logical_and",
    "logical_and_",
    "logical_not",
    "logical_or",
    "logical_or_",
    "logical_xor",
    "logit",
    "logit_out",
    "logspace",
    "lt",
    "lt_scalar",
    "margin_ranking_loss",
    "masked_fill",
    "masked_fill_",
    "masked_scatter",
    "masked_scatter_",
    "masked_select",
    "max",
    "max_dim",
    "max_pool2d_backward",
    "max_pool2d_with_indices",
    "maximum",
    "mean",
    "mean_dim",
    "min",
    "min_dim",
    "minimum",
    "mm",
    "mm_out",
    "mse_loss",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "nan_to_num",
    "ne",
    "ne_scalar",
    "neg",
    "neg_",
    "nll_loss2d_backward",
    "nll_loss2d_forward",
    "nll_loss_backward",
    "nll_loss_forward",
    "nll_loss_nd_backward",
    "nll_loss_nd_forward",
    "nonzero",
    "normal_",
    "normal_float_tensor",
    "normal_tensor_float",
    "normal_tensor_tensor",
    "normed_cumsum",
    "one_hot",
    "ones",
    "ones_like",
    "pad",
    "per_token_group_quant_fp8",
    "pixel_unshuffle",
    "pixel_unshuffle_out",
    "polar",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prelu",
    "prod",
    "prod_dim",
    "quantile",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "reciprocal",
    "reciprocal_",
    "reflection_pad1d",
    "reflection_pad1d_out",
    "reflection_pad2d",
    "reflection_pad2d_out",
    "relu",
    "relu6",
    "relu_",
    "remainder",
    "remainder_",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "replication_pad1d",
    "replication_pad1d_out",
    "replication_pad3d",
    "resolve_conj",
    "resolve_neg",
    "rms_norm",
    "rms_norm_backward",
    "rms_norm_forward",
    "rrelu_with_noise_backward",
    "rsqrt",
    "rsqrt_",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scaled_softmax_backward",
    "scaled_softmax_forward",
    "scatter",
    "scatter_",
    "scatter_add_",
    "select_scatter",
    "selu",
    "selu_",
    "sgn_",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "silu",
    "silu_",
    "silu_backward",
    "sin",
    "sin_",
    "sinh_",
    "slice_backward",
    "slice_scatter",
    "soft_margin_loss",
    "soft_margin_loss_out",
    "softmax",
    "softmax_backward",
    "softplus",
    "softshrink",
    "softshrink_out",
    "sort",
    "sort_stable",
    "special_i0e",
    "special_i0e_out",
    "special_i1",
    "special_i1_out",
    "sqrt",
    "sqrt_",
    "stack",
    "std",
    "sub",
    "sub_",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "t_copy",
    "t_copy_out",
    "tan",
    "tan_",
    "tanh",
    "tanh_",
    "tanh_backward",
    "threshold",
    "threshold_backward",
    "tile",
    "to_copy",
    "topk",
    "trace",
    "tril",
    "tril_out",
    "triu",
    "triu_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "unfold_backward",
    "uniform_",
    "upsample_bicubic2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "var_mean",
    "vdot",
    "vector_norm",
    "vstack",
    "w8a8_block_fp8_matmul",
    "weight_norm_interface",
    "weight_norm_interface_backward",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero",
    "zero_",
    "zero_out",
    "zeros",
    "zeros_like",
]