Coverage for src/flag_gems/runtime/backend/_mthreads/ops/__init__.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1from torch_musa import current_device, get_device_capability
3from .all import all, all_dim, all_dims
4from .amax import amax
5from .any import any, any_dim, any_dims
6from .arange import arange, arange_start
7from .argmin import argmin
8from .batch_norm import batch_norm, batch_norm_backward
9from .celu import celu
10from .conv2d import conv2d
11from .dropout import dropout, dropout_backward
12from .gather import gather, gather_backward
13from .index_add import index_add, index_add_
14from .index_put import index_put, index_put_
15from .index_select import index_select
16from .log import log
17from .log_softmax import log_softmax, log_softmax_backward
18from .max import max, max_dim
19from .min import min, min_dim
20from .normal import normal_
21from .one_hot import one_hot
22from .ones import ones
23from .ones_like import ones_like
24from .prod import prod, prod_dim
25from .rand import rand
26from .rand_like import rand_like
27from .randn import randn
28from .randn_like import randn_like
29from .randperm import randperm
30from .repeat import repeat
31from .repeat_interleave import (
32 repeat_interleave_self_int,
33 repeat_interleave_self_tensor,
34 repeat_interleave_tensor,
35)
36from .resolve_conj import resolve_conj
37from .sort import sort, sort_stable
38from .tile import tile
39from .zeros import zero_, zeros
40from .zeros_like import zeros_like
42__all__ = [
43 "amax",
44 "rand",
45 "rand_like",
46 "dropout",
47 "dropout_backward",
48 "celu",
49 # "celu_",
50 "one_hot",
51 "ones",
52 "ones_like",
53 "randn",
54 "randn_like",
55 "zeros",
56 "zero_",
57 "zeros_like",
58 "log",
59 "log_softmax",
60 "log_softmax_backward",
61 "sort",
62 "arange",
63 "arange_start",
64 "sort_stable",
65 "randperm",
66 "repeat",
67 "repeat_interleave_self_int",
68 "repeat_interleave_self_tensor",
69 "repeat_interleave_tensor",
70 "conv2d",
71 "all",
72 "all_dim",
73 "all_dims",
74 "any",
75 "any_dim",
76 "any_dims",
77 "argmin",
78 "prod",
79 "prod_dim",
80 "min",
81 "min_dim",
82 "max",
83 "max_dim",
84 "batch_norm",
85 "batch_norm_backward",
86 "gather",
87 "gather_backward",
88 "index_add",
89 "index_add_",
90 "index_put",
91 "index_put_",
92 "index_select",
93 "resolve_conj",
94 "normal_",
95 "tile",
96]
98if get_device_capability(current_device())[0] >= 3:
99    from .addmm import addmm
100    from .bmm import bmm
101    from .gelu import gelu
102    from .mm import mm
103    from .tanh import tanh
105    __all__ += ["gelu"]
106    __all__ += ["tanh"]
107    __all__ += ["mm"]
108    __all__ += ["addmm"]
109    __all__ += ["bmm"]