Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mul.py: 0%
27 statements
Report generated by coverage.py v7.6.9 at 2026-03-22 16:54 +0800.
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger, suffixed with this module's
# name. lstrip(".") only matters if __name__ ever starts with a dot; for a
# normally imported module it is a no-op.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    # Per-element body of the tensor * tensor product. In this file it is used
    # by mul() when both operands are tensors and by mul_() when B is a tensor.
    # promotion_methods=[(0, 1, "DEFAULT")] presumably tells pointwise_dynamic
    # to derive the result dtype from arguments 0 and 1 with its default
    # promotion rule -- confirm against pointwise_dynamic.
    product = x * y
    return product
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    # Per-element body of the tensor * scalar product. is_tensor=[True, False]
    # flags the second argument as a non-tensor, so within this file callers
    # always pass the tensor operand first and the scalar second.
    # promotion_methods is presumably the same dtype-promotion rule as the
    # tensor * tensor kernel -- confirm against pointwise_dynamic.
    scaled = x * y
    return scaled
def mul(A, B):
    """Multiply ``A`` and ``B``, where either may be a tensor or a scalar.

    Dispatch:
      * tensor * tensor -> ``mul_func(A, B)``
      * tensor * scalar -> ``mul_func_scalar(A, B)``
      * scalar * tensor -> ``mul_func_scalar(B, A)``; the operands are swapped
        because multiplication commutes and the scalar kernel expects the
        tensor first
      * scalar * scalar -> ``torch.tensor(A * B)``, computed in Python and
        wrapped with no explicit device or dtype
    """
    logger.debug("GEMS MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return mul_func(A, B)
    if a_is_tensor:
        return mul_func_scalar(A, B)
    if b_is_tensor:
        return mul_func_scalar(B, A)
    # Neither operand is a tensor.
    return torch.tensor(A * B)
def mul_(A, B):
    """In-place multiply: compute ``A * B`` and write the result into ``A``.

    ``A`` is passed as ``out0``, so it is the output buffer. A tensor ``B``
    uses the tensor * tensor kernel. Any other ``B`` uses the tensor * scalar
    kernel. Returns whatever the chosen kernel returns when given ``out0=A``
    (presumably ``A`` itself -- confirm against pointwise_dynamic).
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)