Coverage for src/flag_gems/ops/__init__.py: 100%
173 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1from flag_gems.ops.abs import abs, abs_
2from flag_gems.ops.acos import acos
3from flag_gems.ops.add import add, add_
4from flag_gems.ops.addcdiv import addcdiv
5from flag_gems.ops.addcmul import addcmul
6from flag_gems.ops.addmm import addmm, addmm_out
7from flag_gems.ops.addmv import addmv, addmv_out
8from flag_gems.ops.addr import addr
9from flag_gems.ops.all import all, all_dim, all_dims
10from flag_gems.ops.amax import amax
11from flag_gems.ops.angle import angle
12from flag_gems.ops.any import any, any_dim, any_dims
13from flag_gems.ops.arange import arange, arange_start
14from flag_gems.ops.argmax import argmax
15from flag_gems.ops.argmin import argmin
16from flag_gems.ops.atan import atan, atan_
17from flag_gems.ops.attention import (
18 ScaleDotProductAttention,
19 flash_attention_forward,
20 flash_attn_varlen_func,
21 scaled_dot_product_attention,
22 scaled_dot_product_attention_backward,
23 scaled_dot_product_attention_forward,
24)
25from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
26from flag_gems.ops.baddbmm import baddbmm
27from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
28from flag_gems.ops.bitwise_and import (
29 bitwise_and_scalar,
30 bitwise_and_scalar_,
31 bitwise_and_scalar_tensor,
32 bitwise_and_tensor,
33 bitwise_and_tensor_,
34)
35from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
36from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
37from flag_gems.ops.bitwise_or import (
38 bitwise_or_scalar,
39 bitwise_or_scalar_,
40 bitwise_or_scalar_tensor,
41 bitwise_or_tensor,
42 bitwise_or_tensor_,
43)
44from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
45from flag_gems.ops.bmm import bmm, bmm_out
46from flag_gems.ops.cat import cat
47from flag_gems.ops.ceil import ceil, ceil_, ceil_out
48from flag_gems.ops.celu import celu, celu_
49from flag_gems.ops.clamp import (
50 clamp,
51 clamp_,
52 clamp_min,
53 clamp_min_,
54 clamp_tensor,
55 clamp_tensor_,
56)
57from flag_gems.ops.contiguous import contiguous
58from flag_gems.ops.conv1d import conv1d
59from flag_gems.ops.conv2d import conv2d
60from flag_gems.ops.conv3d import conv3d
61from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
62from flag_gems.ops.copy import copy, copy_
63from flag_gems.ops.cos import cos, cos_
64from flag_gems.ops.count_nonzero import count_nonzero
65from flag_gems.ops.cummax import cummax
66from flag_gems.ops.cummin import cummin
67from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
68from flag_gems.ops.diag import diag
69from flag_gems.ops.diag_embed import diag_embed
70from flag_gems.ops.diagonal import diagonal_backward
71from flag_gems.ops.div import (
72 div_mode,
73 div_mode_,
74 floor_divide,
75 floor_divide_,
76 remainder,
77 remainder_,
78 true_divide,
79 true_divide_,
80 true_divide_out,
81)
82from flag_gems.ops.dot import dot
83from flag_gems.ops.dropout import dropout, dropout_backward
84from flag_gems.ops.elu import elu, elu_, elu_backward
85from flag_gems.ops.embedding import embedding, embedding_backward
86from flag_gems.ops.eq import eq, eq_scalar, equal
87from flag_gems.ops.erf import erf, erf_
88from flag_gems.ops.exp import exp, exp_, exp_out
89from flag_gems.ops.exp2 import exp2, exp2_
90from flag_gems.ops.exponential_ import exponential_
91from flag_gems.ops.eye import eye
92from flag_gems.ops.eye_m import eye_m
93from flag_gems.ops.fill import fill_scalar, fill_scalar_, fill_tensor, fill_tensor_
94from flag_gems.ops.flip import flip
95from flag_gems.ops.full import full
96from flag_gems.ops.full_like import full_like
97from flag_gems.ops.gather import gather, gather_backward
98from flag_gems.ops.ge import ge, ge_scalar
99from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
100from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
101from flag_gems.ops.glu import glu, glu_backward
102from flag_gems.ops.groupnorm import group_norm, group_norm_backward
103from flag_gems.ops.gt import gt, gt_scalar
104from flag_gems.ops.hstack import hstack
105from flag_gems.ops.index import index
106from flag_gems.ops.index_add import index_add, index_add_
107from flag_gems.ops.index_put import index_put, index_put_
108from flag_gems.ops.index_select import index_select
109from flag_gems.ops.isclose import allclose, isclose
110from flag_gems.ops.isfinite import isfinite
111from flag_gems.ops.isin import isin
112from flag_gems.ops.isinf import isinf
113from flag_gems.ops.isnan import isnan
114from flag_gems.ops.kron import kron
115from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
116from flag_gems.ops.le import le, le_scalar
117from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
118from flag_gems.ops.linspace import linspace
119from flag_gems.ops.log import log
120from flag_gems.ops.log_sigmoid import log_sigmoid
121from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
122from flag_gems.ops.logical_and import logical_and, logical_and_
123from flag_gems.ops.logical_not import logical_not
124from flag_gems.ops.logical_or import logical_or, logical_or_
125from flag_gems.ops.logical_xor import logical_xor
126from flag_gems.ops.logspace import logspace
127from flag_gems.ops.lt import lt, lt_scalar
128from flag_gems.ops.masked_fill import masked_fill, masked_fill_
129from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
130from flag_gems.ops.masked_select import masked_select
131from flag_gems.ops.max import max, max_dim
132from flag_gems.ops.max_pool2d_with_indices import (
133 max_pool2d_backward,
134 max_pool2d_with_indices,
135)
136from flag_gems.ops.maximum import maximum
137from flag_gems.ops.mean import mean, mean_dim
138from flag_gems.ops.min import min, min_dim
139from flag_gems.ops.minimum import minimum
140from flag_gems.ops.mm import mm, mm_out
141from flag_gems.ops.mse_loss import mse_loss
142from flag_gems.ops.mul import mul, mul_
143from flag_gems.ops.multinomial import multinomial
144from flag_gems.ops.mv import mv
145from flag_gems.ops.nan_to_num import nan_to_num
146from flag_gems.ops.ne import ne, ne_scalar
147from flag_gems.ops.neg import neg, neg_
148from flag_gems.ops.nllloss import (
149 nll_loss2d_backward,
150 nll_loss2d_forward,
151 nll_loss_backward,
152 nll_loss_forward,
153)
154from flag_gems.ops.nonzero import nonzero
155from flag_gems.ops.normal import (
156 normal_,
157 normal_float_tensor,
158 normal_tensor_float,
159 normal_tensor_tensor,
160)
161from flag_gems.ops.one_hot import one_hot
162from flag_gems.ops.ones import ones
163from flag_gems.ops.ones_like import ones_like
164from flag_gems.ops.pad import constant_pad_nd, pad
165from flag_gems.ops.per_token_group_quant_fp8 import (
166 SUPPORTED_FP8_DTYPE,
167 per_token_group_quant_fp8,
168)
169from flag_gems.ops.polar import polar
170from flag_gems.ops.pow import (
171 pow_scalar,
172 pow_tensor_scalar,
173 pow_tensor_scalar_,
174 pow_tensor_tensor,
175 pow_tensor_tensor_,
176)
177from flag_gems.ops.prod import prod, prod_dim
178from flag_gems.ops.quantile import quantile
179from flag_gems.ops.rand import rand
180from flag_gems.ops.rand_like import rand_like
181from flag_gems.ops.randn import randn
182from flag_gems.ops.randn_like import randn_like
183from flag_gems.ops.randperm import randperm
184from flag_gems.ops.reciprocal import reciprocal, reciprocal_
185from flag_gems.ops.relu import relu, relu_
186from flag_gems.ops.repeat import repeat
187from flag_gems.ops.repeat_interleave import (
188 repeat_interleave_self_int,
189 repeat_interleave_self_tensor,
190 repeat_interleave_tensor,
191)
192from flag_gems.ops.resolve_conj import resolve_conj
193from flag_gems.ops.resolve_neg import resolve_neg
194from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
195from flag_gems.ops.rsqrt import rsqrt, rsqrt_
196from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
197from flag_gems.ops.scatter import scatter, scatter_
198from flag_gems.ops.scatter_add_ import scatter_add_
199from flag_gems.ops.select_scatter import select_scatter
200from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
201from flag_gems.ops.silu import silu, silu_, silu_backward
202from flag_gems.ops.sin import sin, sin_
203from flag_gems.ops.slice_scatter import slice_scatter
204from flag_gems.ops.softmax import softmax, softmax_backward
205from flag_gems.ops.softplus import softplus
206from flag_gems.ops.sort import sort, sort_stable
207from flag_gems.ops.sqrt import sqrt, sqrt_
208from flag_gems.ops.stack import stack
209from flag_gems.ops.std import std
210from flag_gems.ops.sub import sub, sub_
211from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
212from flag_gems.ops.tan import tan, tan_
213from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
214from flag_gems.ops.threshold import threshold, threshold_backward
215from flag_gems.ops.tile import tile
216from flag_gems.ops.to import to_copy
217from flag_gems.ops.topk import topk
218from flag_gems.ops.trace import trace
219from flag_gems.ops.triu import triu, triu_
220from flag_gems.ops.unfold_backward import unfold_backward
221from flag_gems.ops.uniform import uniform_
222from flag_gems.ops.unique import _unique2
223from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
224from flag_gems.ops.upsample_linear1d import upsample_linear1d
225from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
226from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
227from flag_gems.ops.var_mean import var_mean
228from flag_gems.ops.vdot import vdot
229from flag_gems.ops.vector_norm import vector_norm
230from flag_gems.ops.vstack import vstack
231from flag_gems.ops.weightnorm import (
232 weight_norm_interface,
233 weight_norm_interface_backward,
234)
235from flag_gems.ops.where import (
236 where_scalar_other,
237 where_scalar_self,
238 where_self,
239 where_self_out,
240)
241from flag_gems.ops.zeros import zero_, zeros
242from flag_gems.ops.zeros_like import zeros_like
# Public API of ``flag_gems.ops``: the names bound by
# ``from flag_gems.ops import *``. Every name imported above is re-exported
# here, and nothing else. Entries are kept in plain ``sorted()`` (ASCII)
# order. That order puts UPPER_CASE constants and CamelCase classes first,
# then underscore-prefixed names, then everything else. With a strict order,
# a missing or duplicated entry is easy to spot when adding a new operator.
__all__ = [
    "SUPPORTED_FP8_DTYPE",
    "ScaleDotProductAttention",
    "_conv_depthwise2d",
    "_unique2",
    "_upsample_bicubic2d_aa",
    "abs",
    "abs_",
    "acos",
    "add",
    "add_",
    "addcdiv",
    "addcmul",
    "addmm",
    "addmm_out",
    "addmv",
    "addmv_out",
    "addr",
    "all",
    "all_dim",
    "all_dims",
    "allclose",
    "amax",
    "angle",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "argmax",
    "argmin",
    "atan",
    "atan_",
    "avg_pool2d",
    "avg_pool2d_backward",
    "baddbmm",
    "batch_norm",
    "batch_norm_backward",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_not",
    "bitwise_not_",
    "bitwise_or_scalar",
    "bitwise_or_scalar_",
    "bitwise_or_scalar_tensor",
    "bitwise_or_tensor",
    "bitwise_or_tensor_",
    "bitwise_right_shift",
    "bmm",
    "bmm_out",
    "cat",
    "ceil",
    "ceil_",
    "ceil_out",
    "celu",
    "celu_",
    "clamp",
    "clamp_",
    "clamp_min",
    "clamp_min_",
    "clamp_tensor",
    "clamp_tensor_",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "copy",
    "copy_",
    "cos",
    "cos_",
    "count_nonzero",
    "cummax",
    "cummin",
    "cumsum",
    "cumsum_out",
    "diag",
    "diag_embed",
    "diagonal_backward",
    "div_mode",
    "div_mode_",
    "dot",
    "dropout",
    "dropout_backward",
    "elu",
    "elu_",
    "elu_backward",
    "embedding",
    "embedding_backward",
    "eq",
    "eq_scalar",
    "equal",
    "erf",
    "erf_",
    "exp",
    "exp2",
    "exp2_",
    "exp_",
    "exp_out",
    "exponential_",
    "eye",
    "eye_m",
    "fill_scalar",
    "fill_scalar_",
    "fill_tensor",
    "fill_tensor_",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "flip",
    "floor_divide",
    "floor_divide_",
    "full",
    "full_like",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "get_scheduler_metadata",
    "glu",
    "glu_backward",
    "group_norm",
    "group_norm_backward",
    "gt",
    "gt_scalar",
    "hstack",
    "index",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isclose",
    "isfinite",
    "isin",
    "isinf",
    "isnan",
    "kron",
    "layer_norm",
    "layer_norm_backward",
    "le",
    "le_scalar",
    "lerp_scalar",
    "lerp_scalar_",
    "lerp_tensor",
    "lerp_tensor_",
    "linspace",
    "log",
    "log_sigmoid",
    "log_softmax",
    "log_softmax_backward",
    "logical_and",
    "logical_and_",
    "logical_not",
    "logical_or",
    "logical_or_",
    "logical_xor",
    "logspace",
    "lt",
    "lt_scalar",
    "masked_fill",
    "masked_fill_",
    "masked_scatter",
    "masked_scatter_",
    "masked_select",
    "max",
    "max_dim",
    "max_pool2d_backward",
    "max_pool2d_with_indices",
    "maximum",
    "mean",
    "mean_dim",
    "min",
    "min_dim",
    "minimum",
    "mm",
    "mm_out",
    "mse_loss",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "nan_to_num",
    "ne",
    "ne_scalar",
    "neg",
    "neg_",
    "nll_loss2d_backward",
    "nll_loss2d_forward",
    "nll_loss_backward",
    "nll_loss_forward",
    "nonzero",
    "normal_",
    "normal_float_tensor",
    "normal_tensor_float",
    "normal_tensor_tensor",
    "normed_cumsum",
    "one_hot",
    "ones",
    "ones_like",
    "pad",
    "per_token_group_quant_fp8",
    "polar",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prod",
    "prod_dim",
    "quantile",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "reciprocal",
    "reciprocal_",
    "relu",
    "relu_",
    "remainder",
    "remainder_",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_conj",
    "resolve_neg",
    "rms_norm",
    "rms_norm_backward",
    "rms_norm_forward",
    "rsqrt",
    "rsqrt_",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scaled_softmax_backward",
    "scaled_softmax_forward",
    "scatter",
    "scatter_",
    "scatter_add_",
    "select_scatter",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "silu",
    "silu_",
    "silu_backward",
    "sin",
    "sin_",
    "slice_scatter",
    "softmax",
    "softmax_backward",
    "softplus",
    "sort",
    "sort_stable",
    "sqrt",
    "sqrt_",
    "stack",
    "std",
    "sub",
    "sub_",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "tan",
    "tan_",
    "tanh",
    "tanh_",
    "tanh_backward",
    "threshold",
    "threshold_backward",
    "tile",
    "to_copy",
    "topk",
    "trace",
    "triu",
    "triu_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "unfold_backward",
    "uniform_",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "var_mean",
    "vdot",
    "vector_norm",
    "vstack",
    "weight_norm_interface",
    "weight_norm_interface_backward",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero_",
    "zeros",
    "zeros_like",
]