Coverage for src/flag_gems/runtime/backend/_mthreads/ops/__init__.py: 0%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1from torch_musa import current_device, get_device_capability
3from .all import all, all_dim, all_dims
4from .any import any, any_dim, any_dims
5from .arange import arange, arange_start
6from .argmin import argmin
7from .batch_norm import batch_norm, batch_norm_backward
8from .celu import celu
9from .conv2d import conv2d
10from .dropout import dropout, dropout_backward
11from .gather import gather, gather_backward
12from .index_put import index_put, index_put_
13from .log import log
14from .max import max, max_dim
15from .min import min, min_dim
16from .ones import ones
17from .ones_like import ones_like
18from .prod import prod, prod_dim
19from .rand import rand
20from .rand_like import rand_like
21from .randn import randn
22from .randn_like import randn_like
23from .randperm import randperm
24from .resolve_conj import resolve_conj
25from .sort import sort, sort_stable
26from .zeros import zero_, zeros
27from .zeros_like import zeros_like
# Public names re-exported by this package. Every entry below is bound by one of
# the ``from .<module> import ...`` statements at the top of this module, so the
# two must be kept in sync whenever an op is added or removed.
# ``celu_`` is not listed: only ``celu`` is imported from ``.celu``.
__all__ = [
    "rand",
    "rand_like",
    "dropout",
    "dropout_backward",
    "celu",
    "ones",
    "ones_like",
    "randn",
    "randn_like",
    "zeros",
    "zero_",
    "zeros_like",
    "log",
    "sort",
    "arange",
    "arange_start",
    "sort_stable",
    "randperm",
    "conv2d",
    "all",
    "all_dim",
    "all_dims",
    "any",
    "any_dim",
    "any_dims",
    "argmin",
    "prod",
    "prod_dim",
    "min",
    "min_dim",
    "max",
    "max_dim",
    "batch_norm",
    "batch_norm_backward",
    "gather",
    "gather_backward",
    "index_put",
    "index_put_",
    "resolve_conj",
]
# The ops imported here are only bound, and only added to ``__all__``, when the
# first element of ``get_device_capability`` for the current MUSA device (the
# major version, assuming torch_musa mirrors torch.cuda's (major, minor) tuple)
# is at least 3.
# NOTE(review): presumably these kernels need features absent on older
# architectures -- confirm against the individual kernel modules.
if get_device_capability(current_device())[0] >= 3:
    from .addmm import addmm
    from .bmm import bmm
    from .gelu import gelu
    from .mm import mm
    from .tanh import tanh

    # Extend the export list in place with the capability-gated names.
    __all__ += ["gelu", "tanh", "mm", "addmm", "bmm"]