Coverage for src/flag_gems/runtime/backend/_cambricon/ops/add.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the "flag_gems" logger; any leading
# dots are stripped from this module's name before it is used as the suffix.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Elementwise `x + y * alpha` where both `x` and `y` are tensor operands.
# `is_tensor` marks arguments 0 and 1 (x, y) as tensors and `alpha` / `inplace`
# as non-tensor arguments. `promotion_methods` derives the result dtype from
# arguments 0 and 1 with the "DEFAULT" rule (presumably PyTorch's default type
# promotion; TODO confirm in pointwise_dynamic).
@pointwise_dynamic(
    is_tensor=[True, True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func(x, y, alpha, inplace):
    # `inplace` is not read here; in this file `add` passes False and `add_`
    # passes True (the latter together with out0=A).
    return x + y * alpha
# Elementwise `x + y * alpha` for a tensor `x` and a non-tensor `y`.
# `is_tensor` marks only argument 0 (x) as a tensor; `y`, `alpha` and
# `inplace` are passed as non-tensor arguments. The result dtype is derived
# from arguments 0 and 1 with the "DEFAULT" promotion rule (presumably
# PyTorch-style promotion of a tensor with a scalar; TODO confirm).
@pointwise_dynamic(
    is_tensor=[True, False, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha, inplace):
    # `inplace` is not read here; callers in this file pass False from `add`
    # and True from `add_`.
    return x + y * alpha
# Elementwise `x + y * alpha` for a non-tensor `x` and a tensor `y`.
# `is_tensor` marks only argument 1 (y) as a tensor; `x`, `alpha` and
# `inplace` are passed as non-tensor arguments. The result dtype is derived
# from arguments 0 and 1 with the "DEFAULT" promotion rule (presumably
# PyTorch-style promotion of a scalar with a tensor; TODO confirm).
@pointwise_dynamic(
    is_tensor=[False, True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha, inplace):
    # `inplace` is not read here; in this file only `add` calls this kernel,
    # passing False.
    return x + y * alpha
def add(A, B, *, alpha=1):
    """Return ``A + B * alpha``, dispatching on which operands are tensors.

    Args:
        A: A ``torch.Tensor`` or a non-tensor scalar.
        B: A ``torch.Tensor`` or a non-tensor scalar.
        alpha: Factor that ``B`` is multiplied by before the addition.

    Returns:
        The output of the matching pointwise kernel when at least one operand
        is a tensor; otherwise ``torch.tensor(A + B * alpha)``.
    """
    logger.debug("GEMS_CAMBRICON ADD")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        if B.device != A.device:
            # PyTorch lets a 0-dim CPU tensor act like a scalar next to a
            # tensor on another device, with the result on that other device.
            # Move the CPU scalar rather than pulling the device tensor onto
            # the CPU, where the Triton kernel cannot run.
            if A.dim() == 0 and A.device.type == "cpu":
                A = A.to(B.device)
            else:
                B = B.to(A.device)
        return add_func(A, B, alpha, False)
    if a_is_tensor:
        return add_func_tensor_scalar(A, B, alpha, False)
    if b_is_tensor:
        return add_func_scalar_tensor(A, B, alpha, False)
    # Neither operand is a tensor: compute in Python and wrap the result.
    return torch.tensor(A + B * alpha)
def add_(A, B, *, alpha=1):
    """In-place ``A += B * alpha``, writing the result into ``A``.

    Args:
        A: Destination ``torch.Tensor``; it is passed as ``out0`` to the
            pointwise kernel.
        B: A ``torch.Tensor`` or a non-tensor scalar. A tensor ``B`` on a
            different device is copied to ``A``'s device first.
        alpha: Factor that ``B`` is multiplied by before the addition.

    Returns:
        Whatever the pointwise kernel returns when called with ``out0=A``
        (presumably ``A`` itself; TODO confirm in pointwise_dynamic).

    Raises:
        ValueError: If ``A`` is not a ``torch.Tensor``, since there is no
            tensor to update in place.
    """
    logger.debug("GEMS_CAMBRICON ADD_")
    if not isinstance(A, torch.Tensor):
        # Reachable whenever A is a scalar (with or without a tensor B):
        # in-place addition needs a destination tensor.
        raise ValueError(
            "add_ requires A to be a torch.Tensor to update in place, "
            f"got {type(A).__name__}."
        )
    if isinstance(B, torch.Tensor):
        # A is the output buffer and must stay where it is, so B moves.
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, True, out0=A)
    return add_func_tensor_scalar(A, B, alpha, True, out0=A)