Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mul.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
import logging

import torch
import triton

from ..utils.pointwise_dynamic import pointwise_dynamic

# Module-level logger nested under the package-wide "flag_gems" logger.
# lstrip(".") removes any leading dots from __name__ (a no-op for normal
# absolute module names; presumably defensive against relative-name forms).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Tensor * tensor elementwise kernel. `pointwise_dynamic` wraps this scalar
# expression into a full pointwise op (the wrapper, not this body, handles
# broadcasting/launch — NOTE(review): semantics inferred from the decorator
# name; confirm against ..utils.pointwise_dynamic). `inplace` is declared
# non-tensor (is_tensor=[True, True, False]) and is unused by the math itself.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y, inplace):
    return x * y
# Tensor * scalar elementwise kernel: same as mul_func but `y` is declared a
# non-tensor scalar (is_tensor=[True, False, False]). Type promotion between
# x and y follows the "DEFAULT" rule of the pointwise_dynamic wrapper.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def mul_func_scalar(x, y, inplace):
    return x * y
def mul(A, B):
    """Out-of-place multiply dispatching on the operand kinds.

    tensor*tensor   -> mul_func (a CPU 0-dim tensor is first moved to the
                       other operand's device; anything else must already
                       be co-located or the assert fires)
    tensor*scalar   -> mul_func_scalar (tensor passed first in either case)
    scalar*scalar   -> plain Python product wrapped in a new tensor
    """
    logger.debug("GEMS_CAMBRICON MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        if A.device != B.device:
            # Only a 0-dim tensor living on the CPU may be migrated.
            if A.dim() == 0:
                assert A.device == torch.device("cpu"), "expect scalar tensor on cpu"
                A = A.to(B.device)
            elif B.dim() == 0:
                assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
                B = B.to(A.device)
        return mul_func(A, B, False)

    if a_is_tensor:
        return mul_func_scalar(A, B, False)
    if b_is_tensor:
        # Multiplication commutes, so the tensor goes in the tensor slot.
        return mul_func_scalar(B, A, False)

    # Neither operand is a tensor: compute in Python, wrap the result.
    return torch.tensor(A * B)
def mul_(A, B):
    """In-place multiply: A *= B, writing the product back into A (out0=A).

    A 0-dim tensor B on the CPU is moved to A's device first; a non-tensor
    B goes through the scalar kernel.
    """
    logger.debug("GEMS_CAMBRICON MUL_")
    if not isinstance(B, torch.Tensor):
        return mul_func_scalar(A, B, True, out0=A)

    # Migrate a mismatched 0-dim tensor; it must start out on the CPU.
    if B.dim() == 0 and B.device != A.device:
        assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
        B = B.to(A.device)
    return mul_func(A, B, True, out0=A)