Coverage for src/flag_gems/ops/mul.py: 81% (62 statements)
Report generated by coverage.py v7.6.9 on 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
8logger = logging.getLogger(__name__)
# Tensor-by-tensor elementwise product. pointwise_dynamic wraps this Triton
# scalar function; promotion_methods lists arguments 0 and 1 as the inputs
# used to pick the result dtype under the "DEFAULT" rule.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    product = x * y
    return product
# Tensor-by-scalar elementwise product: is_tensor=[True, False] marks `x` as
# a tensor argument and `y` as a non-tensor (scalar) argument; the result
# dtype is promoted over both with the "DEFAULT" rule.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    scaled = x * y
    return scaled
# Complex product on split real/imaginary planes:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# All four inputs are tensors; the kernel yields two outputs (real, imag),
# each promoted over all four inputs with the "DEFAULT" rule.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    re_re = ar * br
    im_im = ai * bi
    re_im = ar * bi
    im_re = ai * br
    return re_re - im_im, re_im + im_re
def _mul_complex_complex(A, B):
    """Multiply two complex tensors through their real/imaginary planes.

    The planes are promoted to a common real dtype, combined by
    ``mul_complex_kernel`` and re-packed into a complex tensor whose shape is
    the broadcast of ``A.shape`` and ``B.shape``.
    """
    Ar = torch.view_as_real(A)
    Br = torch.view_as_real(B)
    ar, ai = Ar[..., 0], Ar[..., 1]
    br, bi = Br[..., 0], Br[..., 1]
    common_dtype = torch.promote_types(ar.dtype, br.dtype)
    ar, ai = ar.to(common_dtype), ai.to(common_dtype)
    br, bi = br.to(common_dtype), bi.to(common_dtype)

    # Size the outputs from the broadcast of BOTH operands. Sizing them like
    # `ar` alone is wrong whenever B broadcasts A to a larger shape, e.g.
    # (3, 1) * (1, 4) or a 0-dim A times a 1-D B.
    out_shape = torch.broadcast_shapes(ar.shape, br.shape)
    real_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
    imag_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
    mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)

    return torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1))


def _mul_complex_real(Z, x):
    """Multiply complex tensor ``Z`` by a real tensor or real scalar ``x``.

    Scaling both the real and the imaginary part by ``x`` is the whole
    product, so the real kernels run directly on ``view_as_real(Z)``.
    """
    Zr = torch.view_as_real(Z)
    if isinstance(x, torch.Tensor):
        # The trailing unit dim lets x broadcast over the (real, imag) pair.
        out_real = mul_func(Zr, x.unsqueeze(-1))
    else:
        out_real = mul_func_scalar(Zr, x)
    return torch.view_as_complex(out_real)


def mul(A, B):
    """Elementwise product ``A * B``, including complex operands.

    Complex operands are split into real/imaginary planes with
    ``torch.view_as_real`` so the Triton kernels only see real dtypes.

    Args:
        A: tensor or Python scalar (real or complex).
        B: tensor or Python scalar (real or complex).

    Returns:
        A tensor with the product. Complex results are cast to
        ``torch.result_type(A, B)``. When neither operand is a tensor, the
        Python product wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not (A_is_tensor or B_is_tensor):
        # Both scalar (real or complex): let Python compute the product.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A_is_complex or B_is_complex:
        # Decide the output dtype from the original operands, before a
        # scalar is wrapped into a tensor below.
        out_dtype = torch.result_type(A, B)
        # view_as_real only accepts tensors, so wrap a complex Python scalar
        # as a 0-dim tensor on the other operand's device. The other operand
        # is necessarily a tensor: the both-scalar case returned above.
        if A_is_complex and not A_is_tensor:
            A = torch.tensor(A, dtype=out_dtype, device=B.device)
        elif B_is_complex and not B_is_tensor:
            B = torch.tensor(B, dtype=out_dtype, device=A.device)

        if A_is_complex and B_is_complex:
            out = _mul_complex_complex(A, B)
        elif A_is_complex:
            out = _mul_complex_real(A, B)
        else:
            # Real-by-complex: multiplication commutes, reuse the same path.
            out = _mul_complex_real(B, A)
        return out.to(out_dtype)

    if A_is_tensor and B_is_tensor:
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    return mul_func_scalar(B, A)
def mul_(A, B):
    """In-place elementwise product: write ``A * B`` into ``A``.

    Args:
        A: tensor updated in place.
        B: tensor or Python scalar (real or complex).

    Returns:
        The tensor written in place (``A``).

    Raises:
        RuntimeError: when the product is complex but ``A`` is not, i.e. the
            result dtype cannot be cast to ``A.dtype``.
    """
    logger.debug("GEMS MUL_")
    B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance(
        B, complex
    )
    if A.is_complex() or B_is_complex:
        # mul_func / mul_func_scalar operate on real values; reuse the
        # real/imaginary decomposition in `mul` and write the result back.
        result = mul(A, B)
        if not torch.can_cast(result.dtype, A.dtype):
            raise RuntimeError(
                f"result type {result.dtype} can't be cast to the desired "
                f"output type {A.dtype}"
            )
        return A.copy_(result)
    if isinstance(B, torch.Tensor):
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)