Coverage for src/flag_gems/ops/__init__.py: 100%
172 statements
coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1from flag_gems.ops.abs import abs, abs_
2from flag_gems.ops.acos import acos
3from flag_gems.ops.add import add, add_
4from flag_gems.ops.addcdiv import addcdiv
5from flag_gems.ops.addcmul import addcmul
6from flag_gems.ops.addmm import addmm, addmm_out
7from flag_gems.ops.addmv import addmv, addmv_out
8from flag_gems.ops.addr import addr
9from flag_gems.ops.all import all, all_dim, all_dims
10from flag_gems.ops.amax import amax
11from flag_gems.ops.angle import angle
12from flag_gems.ops.any import any, any_dim, any_dims
13from flag_gems.ops.arange import arange, arange_start
14from flag_gems.ops.argmax import argmax
15from flag_gems.ops.argmin import argmin
16from flag_gems.ops.atan import atan, atan_
17from flag_gems.ops.attention import (
18 ScaleDotProductAttention,
19 flash_attention_forward,
20 flash_attn_varlen_func,
21 scaled_dot_product_attention,
22 scaled_dot_product_attention_backward,
23 scaled_dot_product_attention_forward,
24)
25from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
26from flag_gems.ops.baddbmm import baddbmm
27from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
28from flag_gems.ops.bitwise_and import (
29 bitwise_and_scalar,
30 bitwise_and_scalar_,
31 bitwise_and_scalar_tensor,
32 bitwise_and_tensor,
33 bitwise_and_tensor_,
34)
35from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
36from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
37from flag_gems.ops.bitwise_or import (
38 bitwise_or_scalar,
39 bitwise_or_scalar_,
40 bitwise_or_scalar_tensor,
41 bitwise_or_tensor,
42 bitwise_or_tensor_,
43)
44from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
45from flag_gems.ops.bmm import bmm, bmm_out
46from flag_gems.ops.cat import cat
47from flag_gems.ops.ceil import ceil, ceil_, ceil_out
48from flag_gems.ops.celu import celu, celu_
49from flag_gems.ops.clamp import (
50 clamp,
51 clamp_,
52 clamp_min,
53 clamp_min_,
54 clamp_tensor,
55 clamp_tensor_,
56)
57from flag_gems.ops.contiguous import contiguous
58from flag_gems.ops.conv1d import conv1d
59from flag_gems.ops.conv2d import conv2d
60from flag_gems.ops.conv3d import conv3d
61from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
62from flag_gems.ops.copy import copy, copy_
63from flag_gems.ops.cos import cos, cos_
64from flag_gems.ops.count_nonzero import count_nonzero
65from flag_gems.ops.cummax import cummax
66from flag_gems.ops.cummin import cummin
67from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
68from flag_gems.ops.diag import diag
69from flag_gems.ops.diag_embed import diag_embed
70from flag_gems.ops.diagonal import diagonal_backward
71from flag_gems.ops.div import (
72 div_mode,
73 div_mode_,
74 floor_divide,
75 floor_divide_,
76 remainder,
77 remainder_,
78 true_divide,
79 true_divide_,
80 true_divide_out,
81)
82from flag_gems.ops.dot import dot
83from flag_gems.ops.dropout import dropout, dropout_backward
84from flag_gems.ops.elu import elu, elu_, elu_backward
85from flag_gems.ops.embedding import embedding, embedding_backward
86from flag_gems.ops.eq import eq, eq_scalar, equal
87from flag_gems.ops.erf import erf, erf_
88from flag_gems.ops.exp import exp, exp_, exp_out
89from flag_gems.ops.exp2 import exp2, exp2_
90from flag_gems.ops.exponential_ import exponential_
91from flag_gems.ops.eye import eye
92from flag_gems.ops.eye_m import eye_m
93from flag_gems.ops.fill import fill_scalar, fill_scalar_, fill_tensor, fill_tensor_
94from flag_gems.ops.flip import flip
95from flag_gems.ops.full import full
96from flag_gems.ops.full_like import full_like
97from flag_gems.ops.gather import gather, gather_backward
98from flag_gems.ops.ge import ge, ge_scalar
99from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
100from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
101from flag_gems.ops.glu import glu, glu_backward
102from flag_gems.ops.groupnorm import group_norm, group_norm_backward
103from flag_gems.ops.gt import gt, gt_scalar
104from flag_gems.ops.hstack import hstack
105from flag_gems.ops.index import index
106from flag_gems.ops.index_add import index_add, index_add_
107from flag_gems.ops.index_put import index_put, index_put_
108from flag_gems.ops.index_select import index_select
109from flag_gems.ops.isclose import allclose, isclose
110from flag_gems.ops.isfinite import isfinite
111from flag_gems.ops.isin import isin
112from flag_gems.ops.isinf import isinf
113from flag_gems.ops.isnan import isnan
114from flag_gems.ops.kron import kron
115from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
116from flag_gems.ops.le import le, le_scalar
117from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
118from flag_gems.ops.linspace import linspace
119from flag_gems.ops.log import log
120from flag_gems.ops.log_sigmoid import log_sigmoid
121from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
122from flag_gems.ops.logical_and import logical_and, logical_and_
123from flag_gems.ops.logical_not import logical_not
124from flag_gems.ops.logical_or import logical_or, logical_or_
125from flag_gems.ops.logical_xor import logical_xor
126from flag_gems.ops.logspace import logspace
127from flag_gems.ops.lt import lt, lt_scalar
128from flag_gems.ops.masked_fill import masked_fill, masked_fill_
129from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
130from flag_gems.ops.masked_select import masked_select
131from flag_gems.ops.max import max, max_dim
132from flag_gems.ops.max_pool2d_with_indices import (
133 max_pool2d_backward,
134 max_pool2d_with_indices,
135)
136from flag_gems.ops.maximum import maximum
137from flag_gems.ops.mean import mean, mean_dim
138from flag_gems.ops.min import min, min_dim
139from flag_gems.ops.minimum import minimum
140from flag_gems.ops.mm import mm, mm_out
141from flag_gems.ops.mse_loss import mse_loss
142from flag_gems.ops.mul import mul, mul_
143from flag_gems.ops.multinomial import multinomial
144from flag_gems.ops.mv import mv
145from flag_gems.ops.nan_to_num import nan_to_num
146from flag_gems.ops.ne import ne, ne_scalar
147from flag_gems.ops.neg import neg, neg_
148from flag_gems.ops.nllloss import (
149 nll_loss2d_backward,
150 nll_loss2d_forward,
151 nll_loss_backward,
152 nll_loss_forward,
153)
154from flag_gems.ops.nonzero import nonzero
155from flag_gems.ops.normal import (
156 normal_,
157 normal_float_tensor,
158 normal_tensor_float,
159 normal_tensor_tensor,
160)
161from flag_gems.ops.one_hot import one_hot
162from flag_gems.ops.ones import ones
163from flag_gems.ops.ones_like import ones_like
164from flag_gems.ops.pad import constant_pad_nd, pad
165from flag_gems.ops.per_token_group_quant_fp8 import (
166 SUPPORTED_FP8_DTYPE,
167 per_token_group_quant_fp8,
168)
169from flag_gems.ops.polar import polar
170from flag_gems.ops.pow import (
171 pow_scalar,
172 pow_tensor_scalar,
173 pow_tensor_scalar_,
174 pow_tensor_tensor,
175 pow_tensor_tensor_,
176)
177from flag_gems.ops.prod import prod, prod_dim
178from flag_gems.ops.quantile import quantile
179from flag_gems.ops.rand import rand
180from flag_gems.ops.rand_like import rand_like
181from flag_gems.ops.randn import randn
182from flag_gems.ops.randn_like import randn_like
183from flag_gems.ops.randperm import randperm
184from flag_gems.ops.reciprocal import reciprocal, reciprocal_
185from flag_gems.ops.relu import relu, relu_
186from flag_gems.ops.repeat import repeat
187from flag_gems.ops.repeat_interleave import (
188 repeat_interleave_self_int,
189 repeat_interleave_self_tensor,
190 repeat_interleave_tensor,
191)
192from flag_gems.ops.resolve_conj import resolve_conj
193from flag_gems.ops.resolve_neg import resolve_neg
194from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
195from flag_gems.ops.rsqrt import rsqrt, rsqrt_
196from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
197from flag_gems.ops.scatter import scatter, scatter_
198from flag_gems.ops.scatter_add_ import scatter_add_
199from flag_gems.ops.select_scatter import select_scatter
200from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
201from flag_gems.ops.silu import silu, silu_, silu_backward
202from flag_gems.ops.sin import sin, sin_
203from flag_gems.ops.slice_scatter import slice_scatter
204from flag_gems.ops.softmax import softmax, softmax_backward
205from flag_gems.ops.softplus import softplus
206from flag_gems.ops.sort import sort, sort_stable
207from flag_gems.ops.sqrt import sqrt, sqrt_
208from flag_gems.ops.stack import stack
209from flag_gems.ops.std import std
210from flag_gems.ops.sub import sub, sub_
211from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
212from flag_gems.ops.tan import tan, tan_
213from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
214from flag_gems.ops.threshold import threshold, threshold_backward
215from flag_gems.ops.tile import tile
216from flag_gems.ops.to import to_copy
217from flag_gems.ops.topk import topk
218from flag_gems.ops.trace import trace
219from flag_gems.ops.triu import triu, triu_
220from flag_gems.ops.uniform import uniform_
221from flag_gems.ops.unique import _unique2
222from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
223from flag_gems.ops.upsample_linear1d import upsample_linear1d
224from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
225from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
226from flag_gems.ops.var_mean import var_mean
227from flag_gems.ops.vdot import vdot
228from flag_gems.ops.vector_norm import vector_norm
229from flag_gems.ops.vstack import vstack
230from flag_gems.ops.weightnorm import (
231 weight_norm_interface,
232 weight_norm_interface_backward,
233)
234from flag_gems.ops.where import (
235 where_scalar_other,
236 where_scalar_self,
237 where_self,
238 where_self_out,
239)
240from flag_gems.ops.zeros import zero_, zeros
241from flag_gems.ops.zeros_like import zeros_like
# Public API of ``flag_gems.ops``: exactly the names bound by
# ``from flag_gems.ops import *``. Every entry must be imported above;
# a name listed here but not imported makes the star-import fail.
#
# Ordered with an isort-style sort (ruff RUF022): SCREAMING_SNAKE_CASE
# constants first, then CamelCase classes, then everything else in natural
# order, so a missing or stale entry is easy to spot against the imports.
#
# NOTE: ``abs``, ``all``, ``any``, ``max``, ``min`` and ``sum`` shadow the
# Python builtins of the same name inside this module and in any namespace
# that star-imports it.
__all__ = [
    "SUPPORTED_FP8_DTYPE",
    "ScaleDotProductAttention",
    "_conv_depthwise2d",
    "_unique2",
    "_upsample_bicubic2d_aa",
    "abs",
    "abs_",
    "acos",
    "add",
    "add_",
    "addcdiv",
    "addcmul",
    "addmm",
    "addmm_out",
    "addmv",
    "addmv_out",
    "addr",
    "all",
    "all_dim",
    "all_dims",
    "allclose",
    "amax",
    "angle",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "argmax",
    "argmin",
    "atan",
    "atan_",
    "avg_pool2d",
    "avg_pool2d_backward",
    "baddbmm",
    "batch_norm",
    "batch_norm_backward",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_not",
    "bitwise_not_",
    "bitwise_or_scalar",
    "bitwise_or_scalar_",
    "bitwise_or_scalar_tensor",
    "bitwise_or_tensor",
    "bitwise_or_tensor_",
    "bitwise_right_shift",
    "bmm",
    "bmm_out",
    "cat",
    "ceil",
    "ceil_",
    "ceil_out",
    "celu",
    "celu_",
    "clamp",
    "clamp_",
    "clamp_min",
    "clamp_min_",
    "clamp_tensor",
    "clamp_tensor_",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "copy",
    "copy_",
    "cos",
    "cos_",
    "count_nonzero",
    "cummax",
    "cummin",
    "cumsum",
    "cumsum_out",
    "diag",
    "diag_embed",
    "diagonal_backward",
    "div_mode",
    "div_mode_",
    "dot",
    "dropout",
    "dropout_backward",
    "elu",
    "elu_",
    "elu_backward",
    "embedding",
    "embedding_backward",
    "eq",
    "eq_scalar",
    "equal",
    "erf",
    "erf_",
    "exp",
    "exp2",
    "exp2_",
    "exp_",
    "exp_out",
    "exponential_",
    "eye",
    "eye_m",
    "fill_scalar",
    "fill_scalar_",
    "fill_tensor",
    "fill_tensor_",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "flip",
    "floor_divide",
    "floor_divide_",
    "full",
    "full_like",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "get_scheduler_metadata",
    "glu",
    "glu_backward",
    "group_norm",
    "group_norm_backward",
    "gt",
    "gt_scalar",
    "hstack",
    "index",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isclose",
    "isfinite",
    "isin",
    "isinf",
    "isnan",
    "kron",
    "layer_norm",
    "layer_norm_backward",
    "le",
    "le_scalar",
    "lerp_scalar",
    "lerp_scalar_",
    "lerp_tensor",
    "lerp_tensor_",
    "linspace",
    "log",
    "log_sigmoid",
    "log_softmax",
    "log_softmax_backward",
    "logical_and",
    "logical_and_",
    "logical_not",
    "logical_or",
    "logical_or_",
    "logical_xor",
    "logspace",
    "lt",
    "lt_scalar",
    "masked_fill",
    "masked_fill_",
    "masked_scatter",
    "masked_scatter_",
    "masked_select",
    "max",
    "max_dim",
    "max_pool2d_backward",
    "max_pool2d_with_indices",
    "maximum",
    "mean",
    "mean_dim",
    "min",
    "min_dim",
    "minimum",
    "mm",
    "mm_out",
    "mse_loss",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "nan_to_num",
    "ne",
    "ne_scalar",
    "neg",
    "neg_",
    "nll_loss2d_backward",
    "nll_loss2d_forward",
    "nll_loss_backward",
    "nll_loss_forward",
    "nonzero",
    "normal_",
    "normal_float_tensor",
    "normal_tensor_float",
    "normal_tensor_tensor",
    "normed_cumsum",
    "one_hot",
    "ones",
    "ones_like",
    "pad",
    "per_token_group_quant_fp8",
    "polar",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prod",
    "prod_dim",
    "quantile",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "reciprocal",
    "reciprocal_",
    "relu",
    "relu_",
    "remainder",
    "remainder_",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_conj",
    "resolve_neg",
    "rms_norm",
    "rms_norm_backward",
    "rms_norm_forward",
    "rsqrt",
    "rsqrt_",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scaled_softmax_backward",
    "scaled_softmax_forward",
    "scatter",
    "scatter_",
    "scatter_add_",
    "select_scatter",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "silu",
    "silu_",
    "silu_backward",
    "sin",
    "sin_",
    "slice_scatter",
    "softmax",
    "softmax_backward",
    "softplus",
    "sort",
    "sort_stable",
    "sqrt",
    "sqrt_",
    "stack",
    "std",
    "sub",
    "sub_",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "tan",
    "tan_",
    "tanh",
    "tanh_",
    "tanh_backward",
    "threshold",
    "threshold_backward",
    "tile",
    "to_copy",
    "topk",
    "trace",
    "triu",
    "triu_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "uniform_",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "var_mean",
    "vdot",
    "vector_norm",
    "vstack",
    "weight_norm_interface",
    "weight_norm_interface_backward",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero_",
    "zeros",
    "zeros_like",
]