Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mul.py: 0%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-30 03:43 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-30 03:43 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module with any leading dots removed (same object that
# logging.getLogger("flag_gems").getChild(<name>) would return).
logger = logging.getLogger(".".join(("flag_gems", __name__.lstrip("."))))
# Elementwise product of two tensor operands. The output dtype is presumably
# derived from inputs 0 and 1 with the "DEFAULT" rule listed in
# `promotion_methods` (see pointwise_dynamic).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    # Plain `*` on the loaded operands; kept as a named value for readability.
    product = x * y
    return product
# Elementwise product of a tensor (first argument) and a non-tensor scalar
# (second argument, per `is_tensor=[True, False]`).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    # `y` is the scalar operand, broadcast by the pointwise wrapper.
    product = x * y
    return product
# Complex product on split real/imaginary planes, producing two outputs:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# All four inputs are tensors; both outputs use the "DEFAULT" promotion rule
# over all four inputs, as configured below.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # First output is the real plane, second the imaginary plane.
    return ar * br - ai * bi, ar * bi + ai * br
def _is_complex_operand(x):
    """Return True if ``x`` is a complex tensor or a Python ``complex`` scalar."""
    if isinstance(x, torch.Tensor):
        return x.is_complex()
    return isinstance(x, complex)


def _mul_complex(A, B):
    """Multiply ``A`` and ``B`` when at least one is complex and one is a tensor.

    Complex operands are viewed as ``(..., 2)`` real tensors via
    ``torch.view_as_real`` so that only real values reach the pointwise
    kernels. The result is always cast to ``torch.result_type(A, B)``.

    The kernels only see the real views, which are ordinary dimensioned
    tensors. Without the cast, a 0-dim operand (which ``torch.result_type``
    treats like a scalar) could widen the output dtype. For example, a
    complex64 tensor times a 0-dim float64 tensor would come out as
    complex128 instead of complex64.
    """
    out_dtype = torch.result_type(A, B)

    # A Python complex scalar cannot be split with view_as_real, so materialize
    # it as a 0-dim tensor of the result dtype on the other operand's device.
    # (Both-scalar inputs never reach this helper, so the other operand is a
    # tensor here.)
    if isinstance(A, complex):
        A = torch.tensor(A, dtype=out_dtype, device=B.device)
    if isinstance(B, complex):
        B = torch.tensor(B, dtype=out_dtype, device=A.device)

    A_is_complex = isinstance(A, torch.Tensor) and A.is_complex()
    B_is_complex = isinstance(B, torch.Tensor) and B.is_complex()

    if A_is_complex and B_is_complex:
        # resolve_conj() materializes lazily-conjugated tensors so that
        # view_as_real sees the actual values.
        Ar = torch.view_as_real(A.resolve_conj())
        Br = torch.view_as_real(B.resolve_conj())
        ar, ai = Ar[..., 0].contiguous(), Ar[..., 1].contiguous()
        br, bi = Br[..., 0].contiguous(), Br[..., 1].contiguous()
        real_out, imag_out = mul_complex_kernel(ar, ai, br, bi)
        out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1))
    else:
        # Exactly one side is complex: (a + i*b) * r == a*r + i*(b*r), so scale
        # both planes of the complex operand's real view by the real operand.
        cplx, other = (A, B) if A_is_complex else (B, A)
        Cr = torch.view_as_real(cplx.resolve_conj())
        if isinstance(other, torch.Tensor):
            # A trailing size-1 dim broadcasts the real tensor over both planes.
            out_real = mul_func(Cr, other.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(Cr, other)
        out = torch.view_as_complex(out_real.contiguous())
    return out.to(out_dtype)


def mul(A, B):
    """Return the elementwise product ``A * B``.

    Each of ``A`` and ``B`` may be a tensor or a Python scalar.

    - Complex operands (complex tensors or Python ``complex`` scalars) are
      handled by ``_mul_complex``. There the result dtype is
      ``torch.result_type(A, B)``.
    - Real tensor operands go to the ``mul_func`` / ``mul_func_scalar``
      pointwise kernels, with a scalar always passed as the second kernel
      argument.
    - Two scalars are multiplied in Python and wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not A_is_tensor and not B_is_tensor:
        # Both scalar
        return torch.tensor(A * B)

    if _is_complex_operand(A) or _is_complex_operand(B):
        return _mul_complex(A, B)

    if A_is_tensor and B_is_tensor:
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    # Multiplication commutes, so the scalar A goes in the kernel's scalar slot.
    return mul_func_scalar(B, A)
def mul_(A, B):
    """In-place ``A *= B``; ``A`` must be a tensor and is returned.

    Real operands are written straight into ``A`` through the ``out0`` output
    of the pointwise kernels. Elsewhere in this module the pointwise kernels
    are only ever fed real dtypes (``mul`` splits complex values with
    ``torch.view_as_real`` first). So when a complex operand is involved, the
    product is computed out of place with ``mul`` and copied back into ``A``.

    Raises:
        RuntimeError: if the promoted result dtype cannot be cast to
            ``A.dtype`` (e.g. a real ``A`` times a complex ``B``). This mirrors
            ``torch.Tensor.mul_``.
    """
    logger.debug("GEMS MUL_")
    if isinstance(B, torch.Tensor):
        B_is_complex = B.is_complex()
    else:
        B_is_complex = isinstance(B, complex)

    if A.is_complex() or B_is_complex:
        out_dtype = torch.result_type(A, B)
        # In-place ops may not change A's dtype category (complex -> real).
        if not torch.can_cast(out_dtype, A.dtype):
            raise RuntimeError(
                f"result type {out_dtype} can't be cast to the desired "
                f"output type {A.dtype}"
            )
        return A.copy_(mul(A, B))

    if isinstance(B, torch.Tensor):
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)