Coverage for src/flag_gems/ops/mul.py: 81%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Module logger; the ops below emit a debug line on every call.
logger = logging.getLogger(name=__name__)

9 

10 

# Elementwise product of two tensor operands. The pointwise_dynamic wrapper is
# called both with and without an explicit ``out0=`` in this module, and with
# operands of different (broadcastable) shapes.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    product = x * y
    return product

15 

16 

# Elementwise product of a tensor (first argument) and a non-tensor scalar
# (second argument, ``is_tensor=False``). Callers must pass the tensor first.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    product = x * y
    return product

21 

22 

# Complex product (ar + i*ai) * (br + i*bi) computed on separate real and
# imaginary planes. Returns (real, imag) as the kernel's two outputs:
#   real = ar*br - ai*bi
#   imag = ar*bi + ai*br
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    re_re = ar * br
    im_im = ai * bi
    re_im = ar * bi
    im_re = ai * br
    return re_re - im_im, re_im + im_re

33 

34 

def _mul_complex(A, B, A_is_complex, B_is_complex):
    """Multiply ``A`` and ``B`` when at least one of them is complex.

    The kernels in this module work on real-valued planes. Complex tensors are
    therefore split with ``torch.view_as_real`` and recombined with
    ``torch.view_as_complex``.

    Args:
        A, B: Operands. At least one is a ``torch.Tensor`` (the caller handles
            the scalar-scalar case), and at least one is complex.
        A_is_complex, B_is_complex: Precomputed complexity flags for A and B.

    Returns:
        The product, cast to ``torch.result_type(A, B)``.
    """
    out_dtype = torch.result_type(A, B)

    # torch.view_as_real only accepts tensors. A Python complex scalar is
    # wrapped in a 0-dim tensor of the result dtype on the other (tensor)
    # operand's device.
    if A_is_complex and not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=out_dtype, device=B.device)
    if B_is_complex and not isinstance(B, torch.Tensor):
        B = torch.tensor(B, dtype=out_dtype, device=A.device)

    if A_is_complex and B_is_complex:
        # 1) Both complex: split into planes and run the dedicated kernel.
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        # Outputs need the broadcast shape of both operands, not A's shape
        # alone (e.g. A: (3,), B: (2, 3) -> (2, 3)).
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        real_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        imag_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)

        out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1))
    elif A_is_complex:
        # 2) A complex, B real. A trailing size-1 axis on a tensor B makes it
        # broadcast across the (real, imag) axis of view_as_real(A).
        Ar = torch.view_as_real(A)
        if isinstance(B, torch.Tensor):
            out_real = mul_func(Ar, B.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(Ar, B)
        out = torch.view_as_complex(out_real)
    else:
        # 3) A real, B complex. Mirror of case 2. The scalar kernel takes the
        # tensor operand first.
        Br = torch.view_as_real(B)
        if isinstance(A, torch.Tensor):
            out_real = mul_func(A.unsqueeze(-1), Br)
        else:
            out_real = mul_func_scalar(Br, A)
        out = torch.view_as_complex(out_real)

    return out.to(out_dtype)


def mul(A, B):
    """Return the elementwise product ``A * B``.

    Args:
        A: A ``torch.Tensor`` or a Python number (bool/int/float/complex).
        B: A ``torch.Tensor`` or a Python number.

    Returns:
        A new tensor. If neither operand is a tensor, returns
        ``torch.tensor(A * B)``. If either operand is complex, the result is
        cast to ``torch.result_type(A, B)``.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not (A_is_tensor or B_is_tensor):
        # Both scalar. This is checked first so that complex Python scalars
        # never reach torch.view_as_real in the complex path.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _mul_complex(A, B, A_is_complex, B_is_complex)

    if A_is_tensor and B_is_tensor:
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    # The scalar kernel expects the tensor first; multiplication commutes.
    return mul_func_scalar(B, A)

87 

88 

def mul_(A, B):
    """Multiply tensor ``A`` by ``B`` in place and return ``A``.

    Args:
        A: The tensor to update in place.
        B: A ``torch.Tensor`` or a Python number.

    Returns:
        ``A``, holding ``A * B``.

    Raises:
        RuntimeError: If ``A`` is real and ``B`` is complex. The complex result
            cannot be stored in a real tensor.
    """
    logger.debug("GEMS MUL_")
    if A.is_complex():
        # The out0= kernels below are the real-valued ones. Compute out of
        # place with the complex-aware mul(), then write back into A.
        # copy_ casts the result to A's dtype.
        return A.copy_(mul(A, B))

    B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance(
        B, complex
    )
    if B_is_complex:
        raise RuntimeError(
            f"mul_: result type {torch.result_type(A, B)} can't be cast to "
            f"the desired output type {A.dtype}"
        )

    if isinstance(B, torch.Tensor):
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)