Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mul.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Module logger, created as a child of the "flag_gems" logger so it sits under
# that logger in the standard logging hierarchy (and therefore inherits its
# effective level and propagates records to its handlers by default).
_parent_logger = logging.getLogger("flag_gems")
logger = _parent_logger.getChild(__name__.lstrip("."))

9 

10 

# Elementwise product of two tensor operands, compiled with Triton and wrapped
# by ``pointwise_dynamic``. Per ``is_tensor``, ``x`` and ``y`` are tensor
# arguments and ``inplace`` is a non-tensor argument (not read by the body).
# ``promotion_methods`` requests "DEFAULT" type promotion over args 0 and 1.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func(x, y, inplace):
    product = x * y
    return product

15 

16 

# Elementwise product of a tensor ``x`` with a non-tensor (scalar) ``y``,
# compiled with Triton and wrapped by ``pointwise_dynamic``. Per ``is_tensor``
# only the first argument is a tensor; ``inplace`` is not read by the body.
# ``promotion_methods`` requests "DEFAULT" type promotion over args 0 and 1.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func_scalar(x, y, inplace):
    product = x * y
    return product

23 

24 

def mul(A, B):
    """Elementwise product ``A * B`` for the Cambricon backend.

    Args:
        A: a ``torch.Tensor`` or a Python scalar.
        B: a ``torch.Tensor`` or a Python scalar.

    Returns:
        The result of ``mul_func`` when both operands are tensors, of
        ``mul_func_scalar`` when exactly one is, and ``torch.tensor(A * B)``
        when both are Python scalars.

    Raises:
        AssertionError: if the two tensors are on different devices and the
            0-dim operand that would need to be moved is not on the CPU.
    """
    logger.debug("GEMS_CAMBRICON MUL")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if A.device != B.device:
            # Only a 0-dim CPU tensor may be moved onto the other operand's
            # device. Test each side for that case before rejecting. The old
            # code asserted on A whenever A was 0-dim, which broke
            # (0-dim device tensor) * (0-dim CPU tensor).
            cpu = torch.device("cpu")
            if A.dim() == 0 and A.device == cpu:
                A = A.to(B.device)
            elif B.dim() == 0 and B.device == cpu:
                B = B.to(A.device)
            elif A.dim() == 0 or B.dim() == 0:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError("expect scalar tensor on cpu")
        return mul_func(A, B, False)
    if isinstance(A, torch.Tensor):
        return mul_func_scalar(A, B, False)
    if isinstance(B, torch.Tensor):
        # Multiplication is commutative, so put the tensor operand first.
        return mul_func_scalar(B, A, False)
    # Both operands are Python scalars.
    return torch.tensor(A * B)

43 

44 

def mul_(A, B):
    """In-place elementwise product: computes ``A * B`` with ``out0=A``.

    Args:
        A: the tensor that receives the result.
        B: a ``torch.Tensor`` or a Python scalar. A 0-dim tensor ``B`` on a
            different device than ``A`` is moved to ``A``'s device; it must
            be on the CPU, otherwise the assertion fails.

    Returns:
        Whatever ``mul_func`` (tensor ``B``) or ``mul_func_scalar``
        (non-tensor ``B``) returns for the ``out0=A`` call.
    """
    logger.debug("GEMS_CAMBRICON MUL_")
    if not isinstance(B, torch.Tensor):
        return mul_func_scalar(A, B, True, out0=A)

    move_scalar = B.device != A.device and B.dim() == 0
    if move_scalar:
        assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
        B = B.to(A.device)
    return mul_func(A, B, True, out0=A)