Coverage for src/flag_gems/ops/__init__.py: 100%
171 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1from flag_gems.ops.abs import abs, abs_
2from flag_gems.ops.acos import acos
3from flag_gems.ops.add import add, add_
4from flag_gems.ops.addcdiv import addcdiv
5from flag_gems.ops.addcmul import addcmul
6from flag_gems.ops.addmm import addmm, addmm_out
7from flag_gems.ops.addmv import addmv, addmv_out
8from flag_gems.ops.addr import addr
9from flag_gems.ops.all import all, all_dim, all_dims
10from flag_gems.ops.amax import amax
11from flag_gems.ops.angle import angle
12from flag_gems.ops.any import any, any_dim, any_dims
13from flag_gems.ops.arange import arange, arange_start
14from flag_gems.ops.argmax import argmax
15from flag_gems.ops.argmin import argmin
16from flag_gems.ops.atan import atan, atan_
17from flag_gems.ops.attention import (
18 ScaleDotProductAttention,
19 flash_attention_forward,
20 flash_attn_varlen_func,
21 scaled_dot_product_attention,
22 scaled_dot_product_attention_backward,
23 scaled_dot_product_attention_forward,
24)
25from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
26from flag_gems.ops.baddbmm import baddbmm
27from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
28from flag_gems.ops.bitwise_and import (
29 bitwise_and_scalar,
30 bitwise_and_scalar_,
31 bitwise_and_scalar_tensor,
32 bitwise_and_tensor,
33 bitwise_and_tensor_,
34)
35from flag_gems.ops.bitwise_left_shift import bitwise_left_shift
36from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_
37from flag_gems.ops.bitwise_or import (
38 bitwise_or_scalar,
39 bitwise_or_scalar_,
40 bitwise_or_scalar_tensor,
41 bitwise_or_tensor,
42 bitwise_or_tensor_,
43)
44from flag_gems.ops.bitwise_right_shift import bitwise_right_shift
45from flag_gems.ops.bmm import bmm, bmm_out
46from flag_gems.ops.cat import cat
47from flag_gems.ops.ceil import ceil, ceil_, ceil_out
48from flag_gems.ops.celu import celu, celu_
49from flag_gems.ops.clamp import (
50 clamp,
51 clamp_,
52 clamp_min,
53 clamp_min_,
54 clamp_tensor,
55 clamp_tensor_,
56)
57from flag_gems.ops.contiguous import contiguous
58from flag_gems.ops.conv1d import conv1d
59from flag_gems.ops.conv2d import conv2d
60from flag_gems.ops.conv3d import conv3d
61from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d
62from flag_gems.ops.copy import copy, copy_
63from flag_gems.ops.cos import cos, cos_
64from flag_gems.ops.count_nonzero import count_nonzero
65from flag_gems.ops.cummax import cummax
66from flag_gems.ops.cummin import cummin
67from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum
68from flag_gems.ops.diag import diag
69from flag_gems.ops.diag_embed import diag_embed
70from flag_gems.ops.diagonal import diagonal_backward
71from flag_gems.ops.div import (
72 div_mode,
73 div_mode_,
74 floor_divide,
75 floor_divide_,
76 remainder,
77 remainder_,
78 true_divide,
79 true_divide_,
80 true_divide_out,
81)
82from flag_gems.ops.dot import dot
83from flag_gems.ops.dropout import dropout, dropout_backward
84from flag_gems.ops.elu import elu, elu_, elu_backward
85from flag_gems.ops.embedding import embedding, embedding_backward
86from flag_gems.ops.eq import eq, eq_scalar, equal
87from flag_gems.ops.erf import erf, erf_
88from flag_gems.ops.exp import exp, exp_, exp_out
89from flag_gems.ops.exp2 import exp2, exp2_
90from flag_gems.ops.exponential_ import exponential_
91from flag_gems.ops.eye import eye
92from flag_gems.ops.eye_m import eye_m
93from flag_gems.ops.fill import fill_scalar, fill_scalar_, fill_tensor, fill_tensor_
94from flag_gems.ops.flip import flip
95from flag_gems.ops.full import full
96from flag_gems.ops.full_like import full_like
97from flag_gems.ops.gather import gather, gather_backward
98from flag_gems.ops.ge import ge, ge_scalar
99from flag_gems.ops.gelu import gelu, gelu_, gelu_backward
100from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata
101from flag_gems.ops.glu import glu, glu_backward
102from flag_gems.ops.groupnorm import group_norm, group_norm_backward
103from flag_gems.ops.gt import gt, gt_scalar
104from flag_gems.ops.hstack import hstack
105from flag_gems.ops.index import index
106from flag_gems.ops.index_add import index_add, index_add_
107from flag_gems.ops.index_put import index_put, index_put_
108from flag_gems.ops.index_select import index_select
109from flag_gems.ops.isclose import allclose, isclose
110from flag_gems.ops.isfinite import isfinite
111from flag_gems.ops.isin import isin
112from flag_gems.ops.isinf import isinf
113from flag_gems.ops.isnan import isnan
114from flag_gems.ops.kron import kron
115from flag_gems.ops.layernorm import layer_norm, layer_norm_backward
116from flag_gems.ops.le import le, le_scalar
117from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_
118from flag_gems.ops.linspace import linspace
119from flag_gems.ops.log import log
120from flag_gems.ops.log_sigmoid import log_sigmoid
121from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward
122from flag_gems.ops.logical_and import logical_and, logical_and_
123from flag_gems.ops.logical_not import logical_not
124from flag_gems.ops.logical_or import logical_or, logical_or_
125from flag_gems.ops.logical_xor import logical_xor
126from flag_gems.ops.logspace import logspace
127from flag_gems.ops.lt import lt, lt_scalar
128from flag_gems.ops.masked_fill import masked_fill, masked_fill_
129from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_
130from flag_gems.ops.masked_select import masked_select
131from flag_gems.ops.max import max, max_dim
132from flag_gems.ops.max_pool2d_with_indices import (
133 max_pool2d_backward,
134 max_pool2d_with_indices,
135)
136from flag_gems.ops.maximum import maximum
137from flag_gems.ops.mean import mean, mean_dim
138from flag_gems.ops.min import min, min_dim
139from flag_gems.ops.minimum import minimum
140from flag_gems.ops.mm import mm, mm_out
141from flag_gems.ops.mse_loss import mse_loss
142from flag_gems.ops.mul import mul, mul_
143from flag_gems.ops.multinomial import multinomial
144from flag_gems.ops.mv import mv
145from flag_gems.ops.nan_to_num import nan_to_num
146from flag_gems.ops.ne import ne, ne_scalar
147from flag_gems.ops.neg import neg, neg_
148from flag_gems.ops.nllloss import (
149 nll_loss2d_backward,
150 nll_loss2d_forward,
151 nll_loss_backward,
152 nll_loss_forward,
153)
154from flag_gems.ops.nonzero import nonzero
155from flag_gems.ops.normal import (
156 normal_,
157 normal_float_tensor,
158 normal_tensor_float,
159 normal_tensor_tensor,
160)
161from flag_gems.ops.one_hot import one_hot
162from flag_gems.ops.ones import ones
163from flag_gems.ops.ones_like import ones_like
164from flag_gems.ops.pad import constant_pad_nd, pad
165from flag_gems.ops.per_token_group_quant_fp8 import (
166 SUPPORTED_FP8_DTYPE,
167 per_token_group_quant_fp8,
168)
169from flag_gems.ops.polar import polar
170from flag_gems.ops.pow import (
171 pow_scalar,
172 pow_tensor_scalar,
173 pow_tensor_scalar_,
174 pow_tensor_tensor,
175 pow_tensor_tensor_,
176)
177from flag_gems.ops.prod import prod, prod_dim
178from flag_gems.ops.quantile import quantile
179from flag_gems.ops.rand import rand
180from flag_gems.ops.rand_like import rand_like
181from flag_gems.ops.randn import randn
182from flag_gems.ops.randn_like import randn_like
183from flag_gems.ops.randperm import randperm
184from flag_gems.ops.reciprocal import reciprocal, reciprocal_
185from flag_gems.ops.relu import relu, relu_
186from flag_gems.ops.repeat import repeat
187from flag_gems.ops.repeat_interleave import (
188 repeat_interleave_self_int,
189 repeat_interleave_self_tensor,
190 repeat_interleave_tensor,
191)
192from flag_gems.ops.resolve_conj import resolve_conj
193from flag_gems.ops.resolve_neg import resolve_neg
194from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward
195from flag_gems.ops.rsqrt import rsqrt, rsqrt_
196from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward
197from flag_gems.ops.scatter import scatter, scatter_
198from flag_gems.ops.scatter_add_ import scatter_add_
199from flag_gems.ops.select_scatter import select_scatter
200from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward
201from flag_gems.ops.silu import silu, silu_, silu_backward
202from flag_gems.ops.sin import sin, sin_
203from flag_gems.ops.slice_scatter import slice_scatter
204from flag_gems.ops.softmax import softmax, softmax_backward
205from flag_gems.ops.softplus import softplus
206from flag_gems.ops.sort import sort, sort_stable
207from flag_gems.ops.sqrt import sqrt, sqrt_
208from flag_gems.ops.stack import stack
209from flag_gems.ops.std import std
210from flag_gems.ops.sub import sub, sub_
211from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out
212from flag_gems.ops.tan import tan, tan_
213from flag_gems.ops.tanh import tanh, tanh_, tanh_backward
214from flag_gems.ops.threshold import threshold, threshold_backward
215from flag_gems.ops.tile import tile
216from flag_gems.ops.to import to_copy
217from flag_gems.ops.topk import topk
218from flag_gems.ops.trace import trace
219from flag_gems.ops.triu import triu, triu_
220from flag_gems.ops.uniform import uniform_
221from flag_gems.ops.unique import _unique2
222from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
223from flag_gems.ops.upsample_nearest1d import upsample_nearest1d
224from flag_gems.ops.upsample_nearest2d import upsample_nearest2d
225from flag_gems.ops.var_mean import var_mean
226from flag_gems.ops.vdot import vdot
227from flag_gems.ops.vector_norm import vector_norm
228from flag_gems.ops.vstack import vstack
229from flag_gems.ops.weightnorm import (
230 weight_norm_interface,
231 weight_norm_interface_backward,
232)
233from flag_gems.ops.where import (
234 where_scalar_other,
235 where_scalar_self,
236 where_self,
237 where_self_out,
238)
239from flag_gems.ops.zeros import zero_, zeros
240from flag_gems.ops.zeros_like import zeros_like
# Explicit public API of this module: exactly the names that
# `from <this module> import *` binds in the importing namespace.
#
# - Every entry is a name imported at the top of this file; when an import is
#   added or removed there, update this list to match.
# - The underscore-prefixed entries (`_conv_depthwise2d`, `_unique2`,
#   `_upsample_bicubic2d_aa`) have to be listed explicitly to be exported,
#   because a star-import without `__all__` skips names that start with "_".
# - Several entries ("abs", "all", "any", "max", "min", "sum") have the same
#   names as Python builtins, so a star-import of this module rebinds those
#   names in the importing namespace.
242__all__ = [
243 "_conv_depthwise2d",
244 "_unique2",
245 "_upsample_bicubic2d_aa",
246 "abs",
247 "abs_",
248 "acos",
249 "add",
250 "add_",
251 "addcdiv",
252 "addcmul",
253 "addmm",
254 "addmm_out",
255 "addmv",
256 "addmv_out",
257 "addr",
258 "all",
259 "all_dim",
260 "all_dims",
261 "allclose",
262 "amax",
263 "angle",
264 "any",
265 "any_dim",
266 "any_dims",
267 "arange",
268 "arange_start",
269 "argmax",
270 "argmin",
271 "atan",
272 "atan_",
273 "avg_pool2d",
274 "avg_pool2d_backward",
275 "baddbmm",
276 "batch_norm",
277 "batch_norm_backward",
278 "bitwise_and_scalar",
279 "bitwise_and_scalar_",
280 "bitwise_and_scalar_tensor",
281 "bitwise_and_tensor",
282 "bitwise_and_tensor_",
283 "bitwise_left_shift",
284 "bitwise_not",
285 "bitwise_not_",
286 "bitwise_or_scalar",
287 "bitwise_or_scalar_",
288 "bitwise_or_scalar_tensor",
289 "bitwise_or_tensor",
290 "bitwise_or_tensor_",
291 "bitwise_right_shift",
292 "bmm",
293 "bmm_out",
294 "cat",
295 "ceil",
296 "ceil_",
297 "ceil_out",
298 "celu",
299 "celu_",
300 "clamp",
301 "clamp_",
302 "clamp_min",
303 "clamp_min_",
304 "clamp_tensor",
305 "clamp_tensor_",
306 "constant_pad_nd",
307 "contiguous",
308 "conv1d",
309 "conv2d",
310 "conv3d",
311 "copy",
312 "copy_",
313 "cos",
314 "cos_",
315 "count_nonzero",
316 "cummax",
317 "cummin",
318 "cumsum",
319 "cumsum_out",
320 "diag",
321 "diag_embed",
322 "diagonal_backward",
323 "div_mode",
324 "div_mode_",
325 "dot",
326 "dropout",
327 "dropout_backward",
328 "elu",
329 "elu_",
330 "elu_backward",
331 "embedding",
332 "embedding_backward",
333 "eq",
334 "eq_scalar",
335 "equal",
336 "erf",
337 "erf_",
338 "exp",
339 "exp_",
340 "exp_out",
341 "exp2",
342 "exp2_",
343 "exponential_",
344 "eye",
345 "eye_m",
346 "fill_scalar",
347 "fill_scalar_",
348 "fill_tensor",
349 "fill_tensor_",
350 "flash_attention_forward",
351 "flash_attn_varlen_func",
352 "flip",
353 "floor_divide",
354 "floor_divide_",
355 "full",
356 "full_like",
357 "gather",
358 "gather_backward",
359 "ge",
360 "ge_scalar",
361 "gelu",
362 "gelu_",
363 "gelu_backward",
364 "get_scheduler_metadata",
365 "glu",
366 "glu_backward",
367 "group_norm",
368 "group_norm_backward",
369 "gt",
370 "gt_scalar",
371 "hstack",
372 "index",
373 "index_add",
374 "index_add_",
375 "index_put",
376 "index_put_",
377 "index_select",
378 "isclose",
379 "isfinite",
380 "isin",
381 "isinf",
382 "isnan",
383 "kron",
384 "layer_norm",
385 "layer_norm_backward",
386 "le",
387 "le_scalar",
388 "lerp_scalar",
389 "lerp_scalar_",
390 "lerp_tensor",
391 "lerp_tensor_",
392 "linspace",
393 "log",
394 "log_sigmoid",
395 "log_softmax",
396 "log_softmax_backward",
397 "logical_and",
398 "logical_and_",
399 "logical_not",
400 "logical_or",
401 "logical_or_",
402 "logical_xor",
403 "logspace",
404 "lt",
405 "lt_scalar",
406 "masked_fill",
407 "masked_fill_",
408 "masked_scatter",
409 "masked_scatter_",
410 "masked_select",
411 "max",
412 "max_dim",
413 "max_pool2d_with_indices",
414 "max_pool2d_backward",
415 "maximum",
416 "mean",
417 "mean_dim",
418 "min",
419 "min_dim",
420 "minimum",
421 "mm",
422 "mm_out",
423 "mse_loss",
424 "mul",
425 "mul_",
426 "multinomial",
427 "mv",
428 "nan_to_num",
429 "ne",
430 "ne_scalar",
431 "neg",
432 "neg_",
433 "nll_loss_backward",
434 "nll_loss_forward",
435 "nll_loss2d_backward",
436 "nll_loss2d_forward",
437 "nonzero",
438 "normal_float_tensor",
439 "normal_tensor_float",
440 "normal_tensor_tensor",
441 "normal_",
442 "normed_cumsum",
443 "ones",
444 "ones_like",
445 "one_hot",
446 "pad",
447 "per_token_group_quant_fp8",
448 "polar",
449 "pow_scalar",
450 "pow_tensor_scalar",
451 "pow_tensor_scalar_",
452 "pow_tensor_tensor",
453 "pow_tensor_tensor_",
454 "prod",
455 "prod_dim",
456 "quantile",
457 "rand",
458 "rand_like",
459 "randn",
460 "randn_like",
461 "randperm",
462 "reciprocal",
463 "reciprocal_",
464 "relu",
465 "relu_",
466 "remainder",
467 "remainder_",
468 "repeat",
469 "repeat_interleave_self_int",
470 "repeat_interleave_self_tensor",
471 "repeat_interleave_tensor",
472 "resolve_conj",
473 "resolve_neg",
474 "rms_norm",
475 "rms_norm_backward",
476 "rms_norm_forward",
477 "rsqrt",
478 "rsqrt_",
479 "scaled_dot_product_attention",
480 "scaled_dot_product_attention_backward",
481 "scaled_dot_product_attention_forward",
482 "scaled_softmax_backward",
483 "scaled_softmax_forward",
484 "scatter",
485 "scatter_",
486 "scatter_add_",
487 "select_scatter",
488 "sigmoid",
489 "sigmoid_",
490 "sigmoid_backward",
491 "silu",
492 "silu_",
493 "silu_backward",
494 "sin",
495 "sin_",
496 "slice_scatter",
497 "softmax",
498 "softmax_backward",
499 "softplus",
500 "sort",
501 "sort_stable",
502 "sqrt",
503 "sqrt_",
504 "stack",
505 "std",
506 "sub",
507 "sub_",
508 "sum",
509 "sum_dim",
510 "sum_dim_out",
511 "sum_out",
512 "ScaleDotProductAttention",
513 "SUPPORTED_FP8_DTYPE",
514 "tan",
515 "tan_",
516 "tanh",
517 "tanh_",
518 "tanh_backward",
519 "threshold",
520 "threshold_backward",
521 "tile",
522 "to_copy",
523 "topk",
524 "trace",
525 "triu",
526 "triu_",
527 "true_divide",
528 "true_divide_",
529 "true_divide_out",
530 "uniform_",
531 "upsample_nearest1d",
532 "upsample_nearest2d",
533 "var_mean",
534 "vdot",
535 "vector_norm",
536 "vstack",
537 "weight_norm_interface",
538 "weight_norm_interface_backward",
539 "where_scalar_other",
540 "where_scalar_self",
541 "where_self",
542 "where_self_out",
543 "zeros",
544 "zero_",
545 "zeros_like",
546]