Coverage for src/flag_gems/runtime/backend/_mthreads/ops/__init__.py: 0% (44 statements)
Report generated by coverage.py v7.6.9 on 2026-03-21 14:31 +0800
1from torch_musa import current_device, get_device_capability
3from .all import all, all_dim, all_dims
4from .amax import amax
5from .any import any, any_dim, any_dims
6from .arange import arange, arange_start
7from .argmin import argmin
8from .batch_norm import batch_norm, batch_norm_backward
9from .celu import celu
10from .conv2d import conv2d
11from .dropout import dropout, dropout_backward
12from .gather import gather, gather_backward
13from .index_add import index_add, index_add_
14from .index_put import index_put, index_put_
15from .index_select import index_select
16from .log import log
17from .log_softmax import log_softmax, log_softmax_backward
18from .max import max, max_dim
19from .min import min, min_dim
20from .normal import normal_
21from .ones import ones
22from .ones_like import ones_like
23from .prod import prod, prod_dim
24from .rand import rand
25from .rand_like import rand_like
26from .randn import randn
27from .randn_like import randn_like
28from .randperm import randperm
29from .repeat_interleave import (
30 repeat_interleave_self_int,
31 repeat_interleave_self_tensor,
32 repeat_interleave_tensor,
33)
34from .resolve_conj import resolve_conj
35from .sort import sort, sort_stable
36from .zeros import zero_, zeros
37from .zeros_like import zeros_like
# Public ops exported by this backend's ``ops`` package.
#
# Entries are listed in the same (alphabetical-by-module) order as the
# ``from .xxx import ...`` lines above, so a missing or stale export is easy
# to spot when adding a new op module. Note that ``all``, ``any``, ``max`` and
# ``min`` are imported from local op modules and therefore shadow the Python
# builtins inside this module.
#
# Only the out-of-place ``celu`` is imported above, so no ``celu_`` is
# exported here.
#
# Ops that require a newer device are appended conditionally further below.
__all__ = [
    "all",
    "all_dim",
    "all_dims",
    "amax",
    "any",
    "any_dim",
    "any_dims",
    "arange",
    "arange_start",
    "argmin",
    "batch_norm",
    "batch_norm_backward",
    "celu",
    "conv2d",
    "dropout",
    "dropout_backward",
    "gather",
    "gather_backward",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "log",
    "log_softmax",
    "log_softmax_backward",
    "max",
    "max_dim",
    "min",
    "min_dim",
    "normal_",
    "ones",
    "ones_like",
    "prod",
    "prod_dim",
    "rand",
    "rand_like",
    "randn",
    "randn_like",
    "randperm",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_conj",
    "sort",
    "sort_stable",
    "zero_",
    "zeros",
    "zeros_like",
]
# The following ops are imported and exported only when the current MUSA
# device reports a major compute capability of 3 or higher; on older devices
# these names are neither defined in nor exported from this module.
# NOTE(review): presumably these kernels rely on features of newer MUSA
# architectures — confirm the exact hardware requirement before changing the
# threshold.
if get_device_capability(current_device())[0] >= 3:
    from .addmm import addmm
    from .bmm import bmm
    from .gelu import gelu
    from .mm import mm
    from .tanh import tanh

    # Extend the export list in place, keeping the original append order.
    __all__.extend(["gelu", "tanh", "mm", "addmm", "bmm"])