Coverage for src/flag_gems/runtime/backend/_mthreads/ops/__init__.py: 0%
43 statements
coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1from torch_musa import current_device, get_device_capability
3from .all import all, all_dim, all_dims
4from .amax import amax
5from .any import any, any_dim, any_dims
6from .arange import arange, arange_start
7from .argmin import argmin
8from .batch_norm import batch_norm, batch_norm_backward
9from .celu import celu
10from .conv2d import conv2d
11from .dropout import dropout, dropout_backward
12from .gather import gather, gather_backward
13from .index_add import index_add, index_add_
14from .index_put import index_put, index_put_
15from .index_select import index_select
16from .log import log
17from .log_softmax import log_softmax, log_softmax_backward
18from .max import max, max_dim
19from .min import min, min_dim
20from .normal import normal_
21from .one_hot import one_hot
22from .ones import ones
23from .ones_like import ones_like
24from .prod import prod, prod_dim
25from .rand import rand
26from .rand_like import rand_like
27from .randn import randn
28from .randn_like import randn_like
29from .randperm import randperm
30from .repeat import repeat
31from .repeat_interleave import (
32 repeat_interleave_self_int,
33 repeat_interleave_self_tensor,
34 repeat_interleave_tensor,
35)
36from .resolve_conj import resolve_conj
37from .sort import sort, sort_stable
38from .tile import tile
39from .zeros import zero_, zeros
40from .zeros_like import zeros_like
# Ops exported unconditionally by this backend. Ops that require a newer
# device compute capability are appended to this list further down.
__all__ = [
    # reductions / logical
    "amax", "all", "all_dim", "all_dims",
    "any", "any_dim", "any_dims",
    # ranges and argmin
    "arange", "arange_start",
    "argmin",
    # normalization / activations / convolution
    "batch_norm", "batch_norm_backward",
    "celu",
    # NOTE(review): the in-place variant "celu_" is intentionally left out of
    # this export list — confirm whether it should be exported.
    "conv2d",
    "dropout", "dropout_backward",
    # gather / index ops
    "gather", "gather_backward",
    "index_add", "index_add_",
    "index_put", "index_put_",
    "index_select",
    # log / softmax
    "log",
    "log_softmax", "log_softmax_backward",
    # max / min
    "max", "max_dim",
    "min", "min_dim",
    # creation and random sampling
    "normal_",
    "one_hot",
    "ones", "ones_like",
    "prod", "prod_dim",
    "rand", "rand_like",
    "randn", "randn_like",
    "randperm",
    # repetition / misc tensor ops
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_conj",
    "sort", "sort_stable",
    "tile",
    "zero_", "zeros", "zeros_like",
]
if get_device_capability(current_device())[0] >= 3:
    # The ops below are imported and exported only when the current MUSA
    # device reports a major compute capability of 3 or higher.
    from .addmm import addmm  # noqa: F401
    from .bmm import bmm  # noqa: F401
    from .gelu import gelu  # noqa: F401
    from .mm import mm  # noqa: F401
    from .tanh import tanh  # noqa: F401

    # In-place list extension: the same list object, with names appended
    # in the same order as the imports above.
    __all__ += ["addmm", "bmm", "gelu", "mm", "tanh"]