Coverage for src/flag_gems/ops/__init__.py: 100%
185 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1from flag_gems.ops.abs import abs, abs_
2from flag_gems.ops.absolute import absolute
3from flag_gems.ops.acos import acos
4from flag_gems.ops.add import add, add_
5from flag_gems.ops.addcdiv import addcdiv
6from flag_gems.ops.addcmul import addcmul
7from flag_gems.ops.addmm import addmm, addmm_out
8from flag_gems.ops.addmv import addmv, addmv_out
9from flag_gems.ops.addr import addr
10from flag_gems.ops.all import all, all_dim, all_dims
11from flag_gems.ops.amax import amax
12from flag_gems.ops.angle import angle
13from flag_gems.ops.any import any, any_dim, any_dims
14from flag_gems.ops.arange import arange, arange_start
15from flag_gems.ops.argmax import argmax
16from flag_gems.ops.argmin import argmin
17from flag_gems.ops.atan import atan, atan_
18from flag_gems.ops.attention import (
19 ScaleDotProductAttention,
20 flash_attention_forward,
21 flash_attn_varlen_func,
22 scaled_dot_product_attention,
23 scaled_dot_product_attention_backward,
24 scaled_dot_product_attention_forward,
25)
26from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
27from flag_gems.ops.baddbmm import baddbmm
28from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
29from flag_gems.ops.bitwise_and import (
30 bitwise_and_scalar,
31 bitwise_and_scalar_,
32 bitwise_and_scalar_tensor,
33 bitwise_and_tensor,
34 bitwise_and_tensor_,
35)
36from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
37from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
38from flag_gems.ops.bitwise_or import (
39 bitwise_or_scalar,
40 bitwise_or_scalar_,
41 bitwise_or_scalar_tensor,
42 bitwise_or_tensor,
43 bitwise_or_tensor_,
44)
45from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
46from flag_gems.ops.bmm import bmm, bmm_out
47from flag_gems.ops.cat import cat
48from flag_gems.ops.ceil import ceil, ceil_, ceil_out
49from flag_gems.ops.celu import celu, celu_
50from flag_gems.ops.clamp import (
51 clamp,
52 clamp_,
53 clamp_min,
54 clamp_min_,
55 clamp_tensor,
56 clamp_tensor_,
57)
58from flag_gems.ops.contiguous import contiguous
59from flag_gems.ops.conv1d import conv1d
60from flag_gems.ops.conv2d import conv2d
61from flag_gems.ops.conv3d import conv3d
62from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
63from flag_gems.ops.copy import copy, copy_
64from flag_gems.ops.cos import cos, cos_
65from flag_gems.ops.count_nonzero import count_nonzero
66from flag_gems.ops.cummax import cummax
67from flag_gems.ops.cummin import cummin
68from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
69from flag_gems.ops.diag import diag
70from flag_gems.ops.diag_embed import diag_embed
71from flag_gems.ops.diagonal import diagonal_backward
72from flag_gems.ops.div import (
73 div_mode,
74 div_mode_,
75 floor_divide,
76 floor_divide_,
77 remainder,
78 remainder_,
79 true_divide,
80 true_divide_,
81 true_divide_out,
82)
83from flag_gems.ops.dot import dot
84from flag_gems.ops.dropout import dropout, dropout_backward
85from flag_gems.ops.elu import elu, elu_, elu_backward
86from flag_gems.ops.embedding import embedding, embedding_backward
87from flag_gems.ops.embedding_dense_backward import embedding_dense_backward
88from flag_gems.ops.eq import eq, eq_scalar, equal
89from flag_gems.ops.erf import erf, erf_
90from flag_gems.ops.exp import exp, exp_, exp_out
91from flag_gems.ops.exp2 import exp2, exp2_
92from flag_gems.ops.exponential_ import exponential_
93from flag_gems.ops.eye import eye
94from flag_gems.ops.eye_m import eye_m
95from flag_gems.ops.fill import (
96 fill_scalar,
97 fill_scalar_,
98 fill_scalar_out,
99 fill_tensor,
100 fill_tensor_,
101 fill_tensor_out,
102)
103from flag_gems.ops.flip import flip
104from flag_gems.ops.full import full
105from flag_gems.ops.full_like import full_like
106from flag_gems.ops.gather import gather, gather_backward
107from flag_gems.ops.ge import ge, ge_scalar
108from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
109from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
110from flag_gems.ops.glu import glu, glu_backward
111from flag_gems.ops.groupnorm import group_norm, group_norm_backward
112from flag_gems.ops.gt import gt, gt_scalar
113from flag_gems.ops.hstack import hstack
114from flag_gems.ops.hypot import hypot, hypot_out
115from flag_gems.ops.i0 import i0, i0_out
116from flag_gems.ops.index import index
117from flag_gems.ops.index_add import index_add, index_add_
118from flag_gems.ops.index_put import index_put, index_put_
119from flag_gems.ops.index_select import index_select
120from flag_gems.ops.isclose import allclose, isclose
121from flag_gems.ops.isfinite import isfinite
122from flag_gems.ops.isin import isin
123from flag_gems.ops.isinf import isinf
124from flag_gems.ops.isnan import isnan
125from flag_gems.ops.kron import kron
126from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
127from flag_gems.ops.le import le, le_scalar
128from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
129from flag_gems.ops.lift_fresh_copy import lift_fresh_copy, lift_fresh_copy_out
130from flag_gems.ops.linspace import linspace
131from flag_gems.ops.log import log
132from flag_gems.ops.log_sigmoid import log_sigmoid
133from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
134from flag_gems.ops.logical_and import logical_and, logical_and_
135from flag_gems.ops.logical_not import logical_not
136from flag_gems.ops.logical_or import logical_or, logical_or_
137from flag_gems.ops.logical_xor import logical_xor
138from flag_gems.ops.logspace import logspace
139from flag_gems.ops.lt import lt, lt_scalar
140from flag_gems.ops.masked_fill import masked_fill, masked_fill_
141from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
142from flag_gems.ops.masked_select import masked_select
143from flag_gems.ops.max import max, max_dim
144from flag_gems.ops.max_pool2d_with_indices import (
145 max_pool2d_backward,
146 max_pool2d_with_indices,
147)
148from flag_gems.ops.maximum import maximum
149from flag_gems.ops.mean import mean, mean_dim
150from flag_gems.ops.min import min, min_dim
151from flag_gems.ops.minimum import minimum
152from flag_gems.ops.mm import mm, mm_out
153from flag_gems.ops.mse_loss import mse_loss
154from flag_gems.ops.mul import mul, mul_
155from flag_gems.ops.multinomial import multinomial
156from flag_gems.ops.mv import mv
157from flag_gems.ops.nan_to_num import nan_to_num
158from flag_gems.ops.ne import ne, ne_scalar
159from flag_gems.ops.neg import neg, neg_
160from flag_gems.ops.nll_loss_nd import nll_loss_nd_backward, nll_loss_nd_forward
161from flag_gems.ops.nllloss import (
162 nll_loss2d_backward,
163 nll_loss2d_forward,
164 nll_loss_backward,
165 nll_loss_forward,
166)
167from flag_gems.ops.nonzero import nonzero
168from flag_gems.ops.normal import (
169 normal_,
170 normal_float_tensor,
171 normal_tensor_float,
172 normal_tensor_tensor,
173)
174from flag_gems.ops.one_hot import one_hot
175from flag_gems.ops.ones import ones
176from flag_gems.ops.ones_like import ones_like
177from flag_gems.ops.pad import constant_pad_nd, pad
178from flag_gems.ops.per_token_group_quant_fp8 import (
179 SUPPORTED_FP8_DTYPE,
180 per_token_group_quant_fp8,
181)
182from flag_gems.ops.polar import polar
183from flag_gems.ops.pow import (
184 pow_scalar,
185 pow_tensor_scalar,
186 pow_tensor_scalar_,
187 pow_tensor_tensor,
188 pow_tensor_tensor_,
189)
190from flag_gems.ops.prod import prod, prod_dim
191from flag_gems.ops.quantile import quantile
192from flag_gems.ops.rand import rand
193from flag_gems.ops.rand_like import rand_like
194from flag_gems.ops.randn import randn
195from flag_gems.ops.randn_like import randn_like
196from flag_gems.ops.randperm import randperm
197from flag_gems.ops.reciprocal import reciprocal, reciprocal_
198from flag_gems.ops.relu import relu, relu_
199from flag_gems.ops.repeat import repeat
200from flag_gems.ops.repeat_interleave import (
201 repeat_interleave_self_int,
202 repeat_interleave_self_tensor,
203 repeat_interleave_tensor,
204)
205from flag_gems.ops.replication_pad3d import replication_pad3d
206from flag_gems.ops.resolve_conj import resolve_conj
207from flag_gems.ops.resolve_neg import resolve_neg
208from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
209from flag_gems.ops.rsqrt import rsqrt, rsqrt_
210from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
211from flag_gems.ops.scatter import scatter, scatter_
212from flag_gems.ops.scatter_add_ import scatter_add_
213from flag_gems.ops.select_scatter import select_scatter
214from flag_gems.ops.sgn_ import sgn_
215from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
216from flag_gems.ops.silu import silu, silu_, silu_backward
217from flag_gems.ops.sin import sin, sin_
218from flag_gems.ops.sinh_ import sinh_
219from flag_gems.ops.slice_backward import slice_backward
220from flag_gems.ops.slice_scatter import slice_scatter
221from flag_gems.ops.softmax import softmax, softmax_backward
222from flag_gems.ops.softplus import softplus
223from flag_gems.ops.sort import sort, sort_stable
224from flag_gems.ops.sqrt import sqrt, sqrt_
225from flag_gems.ops.stack import stack
226from flag_gems.ops.std import std
227from flag_gems.ops.sub import sub, sub_
228from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
229from flag_gems.ops.tan import tan, tan_
230from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
231from flag_gems.ops.threshold import threshold, threshold_backward
232from flag_gems.ops.tile import tile
233from flag_gems.ops.to import to_copy
234from flag_gems.ops.topk import topk
235from flag_gems.ops.trace import trace
236from flag_gems.ops.triu import triu, triu_
237from flag_gems.ops.unfold_backward import unfold_backward
238from flag_gems.ops.uniform import uniform_
239from flag_gems.ops.unique import _unique2
240from flag_gems.ops.upsample_bicubic2d import upsample_bicubic2d
241from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
242from flag_gems.ops.upsample_linear1d import upsample_linear1d
243from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
244from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
245from flag_gems.ops.upsample_nearest3d import upsample_nearest3d
246from flag_gems.ops.var_mean import var_mean
247from flag_gems.ops.vdot import vdot
248from flag_gems.ops.vector_norm import vector_norm
249from flag_gems.ops.vstack import vstack
250from flag_gems.ops.weightnorm import (
251 weight_norm_interface,
252 weight_norm_interface_backward,
253)
254from flag_gems.ops.where import (
255 where_scalar_other,
256 where_scalar_self,
257 where_self,
258 where_self_out,
259)
260from flag_gems.ops.zeros import zero_, zeros
261from flag_gems.ops.zeros_like import zeros_like
# Public names re-exported by ``flag_gems.ops``; every entry corresponds to a
# name imported at the top of this module.
#
# Kept in plain ``sorted()`` (ASCII) order: uppercase-initial names first,
# then ``_``-prefixed names, then everything else. Consistent ordering makes
# it easy to see where a new op belongs and to spot duplicates or omissions
# (``__all__ == sorted(__all__)`` should always hold).
__all__ = [
    "SUPPORTED_FP8_DTYPE",
    "ScaleDotProductAttention",
    "_conv_depthwise2d",
    "_unique2",
    "_upsample_bicubic2d_aa",
    "abs",
    "abs_",
    "absolute",
    "acos",
    "add",
    "add_",
    "addcdiv",
    "addcmul",
    "addmm",
    "addmm_out",
    "addmv",
    "addmv_out",
    "addr",
    "all",
    "all_dim",
    "all_dims",
    "allclose",
    "amax",
    "angle",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "argmax",
    "argmin",
    "atan",
    "atan_",
    "avg_pool2d",
    "avg_pool2d_backward",
    "baddbmm",
    "batch_norm",
    "batch_norm_backward",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_not",
    "bitwise_not_",
    "bitwise_or_scalar",
    "bitwise_or_scalar_",
    "bitwise_or_scalar_tensor",
    "bitwise_or_tensor",
    "bitwise_or_tensor_",
    "bitwise_right_shift",
    "bmm",
    "bmm_out",
    "cat",
    "ceil",
    "ceil_",
    "ceil_out",
    "celu",
    "celu_",
    "clamp",
    "clamp_",
    "clamp_min",
    "clamp_min_",
    "clamp_tensor",
    "clamp_tensor_",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "copy",
    "copy_",
    "cos",
    "cos_",
    "count_nonzero",
    "cummax",
    "cummin",
    "cumsum",
    "cumsum_out",
    "diag",
    "diag_embed",
    "diagonal_backward",
    "div_mode",
    "div_mode_",
    "dot",
    "dropout",
    "dropout_backward",
    "elu",
    "elu_",
    "elu_backward",
    "embedding",
    "embedding_backward",
    "embedding_dense_backward",
    "eq",
    "eq_scalar",
    "equal",
    "erf",
    "erf_",
    "exp",
    "exp2",
    "exp2_",
    "exp_",
    "exp_out",
    "exponential_",
    "eye",
    "eye_m",
    "fill_scalar",
    "fill_scalar_",
    "fill_scalar_out",
    "fill_tensor",
    "fill_tensor_",
    "fill_tensor_out",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "flip",
    "floor_divide",
    "floor_divide_",
    "full",
    "full_like",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "get_scheduler_metadata",
    "glu",
    "glu_backward",
    "group_norm",
    "group_norm_backward",
    "gt",
    "gt_scalar",
    "hstack",
    "hypot",
    "hypot_out",
    "i0",
    "i0_out",
    "index",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isclose",
    "isfinite",
    "isin",
    "isinf",
    "isnan",
    "kron",
    "layer_norm",
    "layer_norm_backward",
    "le",
    "le_scalar",
    "lerp_scalar",
    "lerp_scalar_",
    "lerp_tensor",
    "lerp_tensor_",
    "lift_fresh_copy",
    "lift_fresh_copy_out",
    "linspace",
    "log",
    "log_sigmoid",
    "log_softmax",
    "log_softmax_backward",
    "logical_and",
    "logical_and_",
    "logical_not",
    "logical_or",
    "logical_or_",
    "logical_xor",
    "logspace",
    "lt",
    "lt_scalar",
    "masked_fill",
    "masked_fill_",
    "masked_scatter",
    "masked_scatter_",
    "masked_select",
    "max",
    "max_dim",
    "max_pool2d_backward",
    "max_pool2d_with_indices",
    "maximum",
    "mean",
    "mean_dim",
    "min",
    "min_dim",
    "minimum",
    "mm",
    "mm_out",
    "mse_loss",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "nan_to_num",
    "ne",
    "ne_scalar",
    "neg",
    "neg_",
    "nll_loss2d_backward",
    "nll_loss2d_forward",
    "nll_loss_backward",
    "nll_loss_forward",
    "nll_loss_nd_backward",
    "nll_loss_nd_forward",
    "nonzero",
    "normal_",
    "normal_float_tensor",
    "normal_tensor_float",
    "normal_tensor_tensor",
    "normed_cumsum",
    "one_hot",
    "ones",
    "ones_like",
    "pad",
    "per_token_group_quant_fp8",
    "polar",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prod",
    "prod_dim",
    "quantile",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "reciprocal",
    "reciprocal_",
    "relu",
    "relu_",
    "remainder",
    "remainder_",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "replication_pad3d",
    "resolve_conj",
    "resolve_neg",
    "rms_norm",
    "rms_norm_backward",
    "rms_norm_forward",
    "rsqrt",
    "rsqrt_",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scaled_softmax_backward",
    "scaled_softmax_forward",
    "scatter",
    "scatter_",
    "scatter_add_",
    "select_scatter",
    "sgn_",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "silu",
    "silu_",
    "silu_backward",
    "sin",
    "sin_",
    "sinh_",
    "slice_backward",
    "slice_scatter",
    "softmax",
    "softmax_backward",
    "softplus",
    "sort",
    "sort_stable",
    "sqrt",
    "sqrt_",
    "stack",
    "std",
    "sub",
    "sub_",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "tan",
    "tan_",
    "tanh",
    "tanh_",
    "tanh_backward",
    "threshold",
    "threshold_backward",
    "tile",
    "to_copy",
    "topk",
    "trace",
    "triu",
    "triu_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "unfold_backward",
    "uniform_",
    "upsample_bicubic2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "var_mean",
    "vdot",
    "vector_norm",
    "vstack",
    "weight_norm_interface",
    "weight_norm_interface_backward",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero_",
    "zeros",
    "zeros_like",
]