Coverage for src/flag_gems/runtime/backend/_cambricon/ops/add.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Module logger, created as a child of the package-wide "flag_gems" logger so that
# handlers/levels configured on "flag_gems" also apply here. The child suffix is this
# module's __name__; lstrip(".") only has an effect if __name__ starts with a dot.
# NOTE(review): if this module is imported as "flag_gems.runtime...", the full logger
# name repeats the prefix ("flag_gems.flag_gems.runtime...") — confirm that is intended.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

9 

10 

# Elementwise kernel for the tensor + tensor case: computes x + y * alpha.
# Arguments 0-1 are tensors; `alpha` and `inplace` are passed as non-tensor values.
# The result dtype is derived from the two tensor operands via "DEFAULT" promotion.
# `inplace` is not read inside the kernel body.
@pointwise_dynamic(
    is_tensor=[True, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha, inplace):
    scaled_y = y * alpha
    return x + scaled_y

17 

18 

# Elementwise kernel for the tensor + scalar case: computes x + y * alpha.
# Only argument 0 (`x`) is a tensor; `y`, `alpha` and `inplace` are non-tensor values.
# Promotion still considers operands 0 and 1 so the scalar `y` can influence dtype.
# `inplace` is not read inside the kernel body.
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha, inplace):
    scaled_y = y * alpha
    return x + scaled_y

25 

26 

# Elementwise kernel for the scalar + tensor case: computes x + y * alpha.
# Only argument 1 (`y`) is a tensor; `x`, `alpha` and `inplace` are non-tensor values.
# Promotion considers operands 0 and 1 so the scalar `x` can influence dtype.
# `inplace` is not read inside the kernel body.
@pointwise_dynamic(
    is_tensor=[False, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha, inplace):
    scaled_y = y * alpha
    return x + scaled_y

33 

34 

def add(A, B, *, alpha=1):
    """Return ``A + B * alpha`` (Cambricon implementation of out-of-place add).

    Dispatches on which operands are ``torch.Tensor`` instances:

    * tensor, tensor -> ``add_func``. If the operands live on different devices,
      they are first brought to a common device (see below).
    * tensor, other  -> ``add_func_tensor_scalar``
    * other, tensor  -> ``add_func_scalar_tensor``
    * other, other   -> computed in Python and wrapped with ``torch.tensor``.

    Device handling for two tensors: a 0-dim CPU tensor behaves like a scalar
    in PyTorch, so it is moved to the other operand's device. In every other
    mismatch, ``B`` is moved to ``A``'s device.

    Args:
        A: First operand; a tensor or (presumably) a Python number.
        B: Second operand; a tensor or (presumably) a Python number.
        alpha: Multiplier applied to ``B`` before the addition.

    Returns:
        The result produced by the selected kernel, or a new tensor for the
        scalar/scalar case.
    """
    logger.debug("GEMS_CAMBRICON ADD")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            if A.dim() == 0 and A.device.type == "cpu":
                # Moving B onto the CPU here would launch the device kernel on
                # CPU tensors; move the scalar-like A to B's device instead.
                A = A.to(B.device)
            else:
                B = B.to(A.device)
        # Final positional argument is the kernel's `inplace` flag.
        return add_func(A, B, alpha, False)
    if isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, False)
    if isinstance(B, torch.Tensor):
        return add_func_scalar_tensor(A, B, alpha, False)
    return torch.tensor(A + B * alpha)

47 

48 

def add_(A, B, *, alpha=1):
    """In-place ``A += B * alpha`` (Cambricon implementation of ``add_``).

    The result is written into ``A`` via ``out0=A``. If ``B`` is a tensor on a
    different device, it is moved to ``A``'s device first, because ``A`` is the
    output and must stay where it is.

    Args:
        A: Destination tensor; must be a ``torch.Tensor``.
        B: Second operand; a tensor or (presumably) a Python number.
        alpha: Multiplier applied to ``B`` before the addition.

    Returns:
        Whatever the selected kernel returns when given ``out0=A`` (presumably
        ``A`` itself — confirm against ``pointwise_dynamic``).

    Raises:
        ValueError: If ``A`` is not a ``torch.Tensor``; a non-tensor cannot be
            updated in place.
    """
    logger.debug("GEMS_CAMBRICON ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError(
            "add_: in-place add requires `A` to be a torch.Tensor, "
            f"got {type(A).__name__}"
        )
    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        # Final positional argument is the kernel's `inplace` flag.
        return add_func(A, B, alpha, True, out0=A)
    return add_func_tensor_scalar(A, B, alpha, True, out0=A)