Coverage for src/flag_gems/ops/polar.py: 82%
17 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
# Module-scoped logger; polar() uses it to emit a debug message per call.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(
    promotion_methods=[
        ((0, 1), "DEFAULT"),
        ((0, 1), "DEFAULT"),
    ],
    num_outputs=2,
)
@triton.jit
def polar_kernel(abs, angle):
    # Elementwise polar -> Cartesian conversion producing two outputs:
    #   out0 = abs * cos(angle)   (real part)
    #   out1 = abs * sin(angle)   (imaginary part)
    # Each output's dtype is governed by its promotion_methods entry over
    # arguments 0 and 1 (presumably standard "DEFAULT" type promotion of
    # abs and angle -- confirm against pointwise_dynamic).
    cos_theta = tl.cos(angle)
    sin_theta = tl.sin(angle)
    return abs * cos_theta, abs * sin_theta
def polar(abs, angle):
    """Build a complex tensor from magnitudes and angles, like ``torch.polar``.

    Each result element is ``abs * cos(angle) + 1j * abs * sin(angle)``.

    Args:
        abs: Magnitudes. Its dtype and device determine the dtype and device
            of the real/imaginary components of the result.
        angle: Angles in radians, broadcast against ``abs``.

    Returns:
        A complex tensor whose shape is the broadcast of ``abs.shape`` and
        ``angle.shape``.
    """
    logger.debug("GEMS POLAR")
    # torch.polar broadcasts its two inputs, so size the result from the
    # broadcast shape instead of from ``abs`` alone (identical when the
    # shapes already match).
    out_shape = torch.broadcast_shapes(abs.shape, angle.shape)
    # Interleaved (real, imag) pairs in a trailing dimension of size 2 is the
    # layout torch.view_as_complex reinterprets as complex values.
    output = torch.empty((*out_shape, 2), dtype=abs.dtype, device=abs.device)
    polar_kernel(abs, angle, out0=output[..., 0], out1=output[..., 1])
    return torch.view_as_complex(output)