Coverage for src/flag_gems/ops/__init__.py: 100%
175 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1from flag_gems.ops.abs import abs, abs_
2from flag_gems.ops.acos import acos
3from flag_gems.ops.add import add, add_
4from flag_gems.ops.addcdiv import addcdiv
5from flag_gems.ops.addcmul import addcmul
6from flag_gems.ops.addmm import addmm, addmm_out
7from flag_gems.ops.addmv import addmv, addmv_out
8from flag_gems.ops.addr import addr
9from flag_gems.ops.all import all, all_dim, all_dims
10from flag_gems.ops.amax import amax
11from flag_gems.ops.angle import angle
12from flag_gems.ops.any import any, any_dim, any_dims
13from flag_gems.ops.arange import arange, arange_start
14from flag_gems.ops.argmax import argmax
15from flag_gems.ops.argmin import argmin
16from flag_gems.ops.atan import atan, atan_
17from flag_gems.ops.attention import (
18 ScaleDotProductAttention,
19 flash_attention_forward,
20 flash_attn_varlen_func,
21 scaled_dot_product_attention,
22 scaled_dot_product_attention_backward,
23 scaled_dot_product_attention_forward,
24)
25from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
26from flag_gems.ops.baddbmm import baddbmm
27from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
28from flag_gems.ops.bitwise_and import (
29 bitwise_and_scalar,
30 bitwise_and_scalar_,
31 bitwise_and_scalar_tensor,
32 bitwise_and_tensor,
33 bitwise_and_tensor_,
34)
35from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
36from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
37from flag_gems.ops.bitwise_or import (
38 bitwise_or_scalar,
39 bitwise_or_scalar_,
40 bitwise_or_scalar_tensor,
41 bitwise_or_tensor,
42 bitwise_or_tensor_,
43)
44from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
45from flag_gems.ops.bmm import bmm, bmm_out
46from flag_gems.ops.cat import cat
47from flag_gems.ops.ceil import ceil, ceil_, ceil_out
48from flag_gems.ops.celu import celu, celu_
49from flag_gems.ops.clamp import (
50 clamp,
51 clamp_,
52 clamp_min,
53 clamp_min_,
54 clamp_tensor,
55 clamp_tensor_,
56)
57from flag_gems.ops.contiguous import contiguous
58from flag_gems.ops.conv1d import conv1d
59from flag_gems.ops.conv2d import conv2d
60from flag_gems.ops.conv3d import conv3d
61from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
62from flag_gems.ops.copy import copy, copy_
63from flag_gems.ops.cos import cos, cos_
64from flag_gems.ops.count_nonzero import count_nonzero
65from flag_gems.ops.cummax import cummax
66from flag_gems.ops.cummin import cummin
67from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
68from flag_gems.ops.diag import diag
69from flag_gems.ops.diag_embed import diag_embed
70from flag_gems.ops.diagonal import diagonal_backward
71from flag_gems.ops.div import (
72 div_mode,
73 div_mode_,
74 floor_divide,
75 floor_divide_,
76 remainder,
77 remainder_,
78 true_divide,
79 true_divide_,
80 true_divide_out,
81)
82from flag_gems.ops.dot import dot
83from flag_gems.ops.dropout import dropout, dropout_backward
84from flag_gems.ops.elu import elu, elu_, elu_backward
85from flag_gems.ops.embedding import embedding, embedding_backward
86from flag_gems.ops.eq import eq, eq_scalar, equal
87from flag_gems.ops.erf import erf, erf_
88from flag_gems.ops.exp import exp, exp_, exp_out
89from flag_gems.ops.exp2 import exp2, exp2_
90from flag_gems.ops.exponential_ import exponential_
91from flag_gems.ops.eye import eye
92from flag_gems.ops.eye_m import eye_m
93from flag_gems.ops.fill import (
94 fill_scalar,
95 fill_scalar_,
96 fill_scalar_out,
97 fill_tensor,
98 fill_tensor_,
99 fill_tensor_out,
100)
101from flag_gems.ops.flip import flip
102from flag_gems.ops.full import full
103from flag_gems.ops.full_like import full_like
104from flag_gems.ops.gather import gather, gather_backward
105from flag_gems.ops.ge import ge, ge_scalar
106from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
107from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
108from flag_gems.ops.glu import glu, glu_backward
109from flag_gems.ops.groupnorm import group_norm, group_norm_backward
110from flag_gems.ops.gt import gt, gt_scalar
111from flag_gems.ops.hstack import hstack
112from flag_gems.ops.index import index
113from flag_gems.ops.index_add import index_add, index_add_
114from flag_gems.ops.index_put import index_put, index_put_
115from flag_gems.ops.index_select import index_select
116from flag_gems.ops.isclose import allclose, isclose
117from flag_gems.ops.isfinite import isfinite
118from flag_gems.ops.isin import isin
119from flag_gems.ops.isinf import isinf
120from flag_gems.ops.isnan import isnan
121from flag_gems.ops.kron import kron
122from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
123from flag_gems.ops.le import le, le_scalar
124from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
125from flag_gems.ops.linspace import linspace
126from flag_gems.ops.log import log
127from flag_gems.ops.log_sigmoid import log_sigmoid
128from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
129from flag_gems.ops.logical_and import logical_and, logical_and_
130from flag_gems.ops.logical_not import logical_not
131from flag_gems.ops.logical_or import logical_or, logical_or_
132from flag_gems.ops.logical_xor import logical_xor
133from flag_gems.ops.logspace import logspace
134from flag_gems.ops.lt import lt, lt_scalar
135from flag_gems.ops.masked_fill import masked_fill, masked_fill_
136from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
137from flag_gems.ops.masked_select import masked_select
138from flag_gems.ops.max import max, max_dim
139from flag_gems.ops.max_pool2d_with_indices import (
140 max_pool2d_backward,
141 max_pool2d_with_indices,
142)
143from flag_gems.ops.maximum import maximum
144from flag_gems.ops.mean import mean, mean_dim
145from flag_gems.ops.min import min, min_dim
146from flag_gems.ops.minimum import minimum
147from flag_gems.ops.mm import mm, mm_out
148from flag_gems.ops.mse_loss import mse_loss
149from flag_gems.ops.mul import mul, mul_
150from flag_gems.ops.multinomial import multinomial
151from flag_gems.ops.mv import mv
152from flag_gems.ops.nan_to_num import nan_to_num
153from flag_gems.ops.ne import ne, ne_scalar
154from flag_gems.ops.neg import neg, neg_
155from flag_gems.ops.nllloss import (
156 nll_loss2d_backward,
157 nll_loss2d_forward,
158 nll_loss_backward,
159 nll_loss_forward,
160)
161from flag_gems.ops.nonzero import nonzero
162from flag_gems.ops.normal import (
163 normal_,
164 normal_float_tensor,
165 normal_tensor_float,
166 normal_tensor_tensor,
167)
168from flag_gems.ops.one_hot import one_hot
169from flag_gems.ops.ones import ones
170from flag_gems.ops.ones_like import ones_like
171from flag_gems.ops.pad import constant_pad_nd, pad
172from flag_gems.ops.per_token_group_quant_fp8 import (
173 SUPPORTED_FP8_DTYPE,
174 per_token_group_quant_fp8,
175)
176from flag_gems.ops.polar import polar
177from flag_gems.ops.pow import (
178 pow_scalar,
179 pow_tensor_scalar,
180 pow_tensor_scalar_,
181 pow_tensor_tensor,
182 pow_tensor_tensor_,
183)
184from flag_gems.ops.prod import prod, prod_dim
185from flag_gems.ops.quantile import quantile
186from flag_gems.ops.rand import rand
187from flag_gems.ops.rand_like import rand_like
188from flag_gems.ops.randn import randn
189from flag_gems.ops.randn_like import randn_like
190from flag_gems.ops.randperm import randperm
191from flag_gems.ops.reciprocal import reciprocal, reciprocal_
192from flag_gems.ops.relu import relu, relu_
193from flag_gems.ops.repeat import repeat
194from flag_gems.ops.repeat_interleave import (
195 repeat_interleave_self_int,
196 repeat_interleave_self_tensor,
197 repeat_interleave_tensor,
198)
199from flag_gems.ops.replication_pad3d import replication_pad3d
200from flag_gems.ops.resolve_conj import resolve_conj
201from flag_gems.ops.resolve_neg import resolve_neg
202from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
203from flag_gems.ops.rsqrt import rsqrt, rsqrt_
204from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
205from flag_gems.ops.scatter import scatter, scatter_
206from flag_gems.ops.scatter_add_ import scatter_add_
207from flag_gems.ops.select_scatter import select_scatter
208from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
209from flag_gems.ops.silu import silu, silu_, silu_backward
210from flag_gems.ops.sin import sin, sin_
211from flag_gems.ops.slice_scatter import slice_scatter
212from flag_gems.ops.softmax import softmax, softmax_backward
213from flag_gems.ops.softplus import softplus
214from flag_gems.ops.sort import sort, sort_stable
215from flag_gems.ops.sqrt import sqrt, sqrt_
216from flag_gems.ops.stack import stack
217from flag_gems.ops.std import std
218from flag_gems.ops.sub import sub, sub_
219from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
220from flag_gems.ops.tan import tan, tan_
221from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
222from flag_gems.ops.threshold import threshold, threshold_backward
223from flag_gems.ops.tile import tile
224from flag_gems.ops.to import to_copy
225from flag_gems.ops.topk import topk
226from flag_gems.ops.trace import trace
227from flag_gems.ops.triu import triu, triu_
228from flag_gems.ops.unfold_backward import unfold_backward
229from flag_gems.ops.uniform import uniform_
230from flag_gems.ops.unique import _unique2
231from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
232from flag_gems.ops.upsample_linear1d import upsample_linear1d
233from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
234from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
235from flag_gems.ops.upsample_nearest3d import upsample_nearest3d
236from flag_gems.ops.var_mean import var_mean
237from flag_gems.ops.vdot import vdot
238from flag_gems.ops.vector_norm import vector_norm
239from flag_gems.ops.vstack import vstack
240from flag_gems.ops.weightnorm import (
241 weight_norm_interface,
242 weight_norm_interface_backward,
243)
244from flag_gems.ops.where import (
245 where_scalar_other,
246 where_scalar_self,
247 where_self,
248 where_self_out,
249)
250from flag_gems.ops.zeros import zero_, zeros
251from flag_gems.ops.zeros_like import zeros_like
253__all__ = [
254 "_conv_depthwise2d",
255 "_unique2",
256 "_upsample_bicubic2d_aa",
257 "abs",
258 "abs_",
259 "acos",
260 "add",
261 "add_",
262 "addcdiv",
263 "addcmul",
264 "addmm",
265 "addmm_out",
266 "addmv",
267 "addmv_out",
268 "addr",
269 "all",
270 "all_dim",
271 "all_dims",
272 "allclose",
273 "amax",
274 "angle",
275 "any",
276 "any_dim",
277 "any_dims",
278 "arange",
279 "arange_start",
280 "argmax",
281 "argmin",
282 "atan",
283 "atan_",
284 "avg_pool2d",
285 "avg_pool2d_backward",
286 "baddbmm",
287 "batch_norm",
288 "batch_norm_backward",
289 "bitwise_and_scalar",
290 "bitwise_and_scalar_",
291 "bitwise_and_scalar_tensor",
292 "bitwise_and_tensor",
293 "bitwise_and_tensor_",
294 "bitwise_left_shift",
295 "bitwise_not",
296 "bitwise_not_",
297 "bitwise_or_scalar",
298 "bitwise_or_scalar_",
299 "bitwise_or_scalar_tensor",
300 "bitwise_or_tensor",
301 "bitwise_or_tensor_",
302 "bitwise_right_shift",
303 "bmm",
304 "bmm_out",
305 "cat",
306 "ceil",
307 "ceil_",
308 "ceil_out",
309 "celu",
310 "celu_",
311 "clamp",
312 "clamp_",
313 "clamp_min",
314 "clamp_min_",
315 "clamp_tensor",
316 "clamp_tensor_",
317 "constant_pad_nd",
318 "contiguous",
319 "conv1d",
320 "conv2d",
321 "conv3d",
322 "copy",
323 "copy_",
324 "cos",
325 "cos_",
326 "count_nonzero",
327 "cummax",
328 "cummin",
329 "cumsum",
330 "cumsum_out",
331 "diag",
332 "diag_embed",
333 "diagonal_backward",
334 "div_mode",
335 "div_mode_",
336 "dot",
337 "dropout",
338 "dropout_backward",
339 "elu",
340 "elu_",
341 "elu_backward",
342 "embedding",
343 "embedding_backward",
344 "eq",
345 "eq_scalar",
346 "equal",
347 "erf",
348 "erf_",
349 "exp",
350 "exp_",
351 "exp_out",
352 "exp2",
353 "exp2_",
354 "exponential_",
355 "eye",
356 "eye_m",
357 "fill_scalar",
358 "fill_scalar_",
359 "fill_scalar_out",
360 "fill_tensor",
361 "fill_tensor_",
362 "fill_tensor_out",
363 "flash_attention_forward",
364 "flash_attn_varlen_func",
365 "flip",
366 "floor_divide",
367 "floor_divide_",
368 "full",
369 "full_like",
370 "gather",
371 "gather_backward",
372 "ge",
373 "ge_scalar",
374 "gelu",
375 "gelu_",
376 "gelu_backward",
377 "get_scheduler_metadata",
378 "glu",
379 "glu_backward",
380 "group_norm",
381 "group_norm_backward",
382 "gt",
383 "gt_scalar",
384 "hstack",
385 "index",
386 "index_add",
387 "index_add_",
388 "index_put",
389 "index_put_",
390 "index_select",
391 "isclose",
392 "isfinite",
393 "isin",
394 "isinf",
395 "isnan",
396 "kron",
397 "layer_norm",
398 "layer_norm_backward",
399 "le",
400 "le_scalar",
401 "lerp_scalar",
402 "lerp_scalar_",
403 "lerp_tensor",
404 "lerp_tensor_",
405 "linspace",
406 "log",
407 "log_sigmoid",
408 "log_softmax",
409 "log_softmax_backward",
410 "logical_and",
411 "logical_and_",
412 "logical_not",
413 "logical_or",
414 "logical_or_",
415 "logical_xor",
416 "logspace",
417 "lt",
418 "lt_scalar",
419 "masked_fill",
420 "masked_fill_",
421 "masked_scatter",
422 "masked_scatter_",
423 "masked_select",
424 "max",
425 "max_dim",
426 "max_pool2d_with_indices",
427 "max_pool2d_backward",
428 "maximum",
429 "mean",
430 "mean_dim",
431 "min",
432 "min_dim",
433 "minimum",
434 "mm",
435 "mm_out",
436 "mse_loss",
437 "mul",
438 "mul_",
439 "multinomial",
440 "mv",
441 "nan_to_num",
442 "ne",
443 "ne_scalar",
444 "neg",
445 "neg_",
446 "nll_loss_backward",
447 "nll_loss_forward",
448 "nll_loss2d_backward",
449 "nll_loss2d_forward",
450 "nonzero",
451 "normal_float_tensor",
452 "normal_tensor_float",
453 "normal_tensor_tensor",
454 "normal_",
455 "normed_cumsum",
456 "ones",
457 "ones_like",
458 "one_hot",
459 "pad",
460 "per_token_group_quant_fp8",
461 "polar",
462 "pow_scalar",
463 "pow_tensor_scalar",
464 "pow_tensor_scalar_",
465 "pow_tensor_tensor",
466 "pow_tensor_tensor_",
467 "prod",
468 "prod_dim",
469 "quantile",
470 "rand",
471 "rand_like",
472 "randn",
473 "randn_like",
474 "randperm",
475 "reciprocal",
476 "reciprocal_",
477 "relu",
478 "relu_",
479 "remainder",
480 "remainder_",
481 "repeat",
482 "repeat_interleave_self_int",
483 "repeat_interleave_self_tensor",
484 "repeat_interleave_tensor",
485 "replication_pad3d",
486 "resolve_conj",
487 "resolve_neg",
488 "rms_norm",
489 "rms_norm_backward",
490 "rms_norm_forward",
491 "rsqrt",
492 "rsqrt_",
493 "scaled_dot_product_attention",
494 "scaled_dot_product_attention_backward",
495 "scaled_dot_product_attention_forward",
496 "scaled_softmax_backward",
497 "scaled_softmax_forward",
498 "scatter",
499 "scatter_",
500 "scatter_add_",
501 "select_scatter",
502 "sigmoid",
503 "sigmoid_",
504 "sigmoid_backward",
505 "silu",
506 "silu_",
507 "silu_backward",
508 "sin",
509 "sin_",
510 "slice_scatter",
511 "softmax",
512 "softmax_backward",
513 "softplus",
514 "sort",
515 "sort_stable",
516 "sqrt",
517 "sqrt_",
518 "stack",
519 "std",
520 "sub",
521 "sub_",
522 "sum",
523 "sum_dim",
524 "sum_dim_out",
525 "sum_out",
526 "ScaleDotProductAttention",
527 "SUPPORTED_FP8_DTYPE",
528 "tan",
529 "tan_",
530 "tanh",
531 "tanh_",
532 "tanh_backward",
533 "threshold",
534 "threshold_backward",
535 "tile",
536 "to_copy",
537 "topk",
538 "trace",
539 "triu",
540 "triu_",
541 "true_divide",
542 "true_divide_",
543 "true_divide_out",
544 "unfold_backward",
545 "uniform_",
546 "upsample_linear1d",
547 "upsample_nearest1d",
548 "upsample_nearest2d",
549 "upsample_nearest3d",
550 "var_mean",
551 "vdot",
552 "vector_norm",
553 "vstack",
554 "weight_norm_interface",
555 "weight_norm_interface_backward",
556 "where_scalar_other",
557 "where_scalar_self",
558 "where_self",
559 "where_self_out",
560 "zeros",
561 "zero_",
562 "zeros_like",
563]