Coverage for src/flag_gems/runtime/backend/_ascend/ops/polar.py: 0%
19 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module-level logger; the final component of this module's dotted name is
# appended to the fixed Ascend ops logger prefix.
logger = logging.getLogger(
    f"flag_gems.runtime._ascend.ops.{__name__.rsplit('.', 1)[-1]}"
)
# Code-generation settings passed to ``pointwise_dynamic`` for the polar op.
# The positional arguments keep their original order and values; their meaning
# is defined by ``CodeGenConfig`` (presumably tile size, grid limits, warp
# count and a block-pointer switch -- TODO confirm against CodeGenConfig).
config_ = CodeGenConfig(
    384,
    (48, 1, 1),
    32,
    False,
    # Enable 1-D tiling on Triton releases older than 3.  The whole major
    # component is parsed so that a multi-digit major version is not mistaken
    # for its first digit.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)
# Scalar polar -> Cartesian function wrapped by ``pointwise_dynamic``.
# It is declared with two outputs, each using the "DEFAULT" promotion rule
# over inputs 0 and 1, and is generated with the module-level ``config_``.
@pointwise_dynamic(
    promotion_methods=[((0, 1), "DEFAULT"), ((0, 1), "DEFAULT")],
    num_outputs=2,
    config=config_,
)
@triton.jit
def polar_kernel(abs, angle):
    # First value is abs * cos(angle), second is abs * sin(angle).
    return abs * tl.cos(angle), abs * tl.sin(angle)
def polar(abs, angle):
    """Build a complex tensor from magnitudes ``abs`` and phases ``angle``.

    ``polar_kernel`` writes its two results into the two slots of a trailing
    axis of size 2; that buffer is then reinterpreted with
    ``torch.view_as_complex``.

    Args:
        abs: Tensor of magnitudes. Its dtype and device are used for the
            output buffer.
        angle: Tensor of phases. Its shape must be broadcastable with
            ``abs.shape``.

    Returns:
        The complex view of the filled buffer. Its shape is the broadcast of
        ``abs.shape`` and ``angle.shape``.

    Raises:
        RuntimeError: If the two shapes cannot be broadcast together
            (raised by ``torch.broadcast_shapes``).
    """
    logger.debug("GEMS_ASCEND POLAR")
    # Size the buffer from the broadcast of both inputs.  Using ``abs.shape``
    # alone yields a buffer that is too small whenever ``angle`` broadcasts to
    # a larger shape; for equal shapes the result is unchanged.
    # NOTE(review): assumes polar_kernel broadcasts its inputs to the shape of
    # the provided outputs -- confirm against pointwise_dynamic.
    out_shape = torch.broadcast_shapes(abs.shape, angle.shape)
    output = torch.empty((*out_shape, 2), dtype=abs.dtype, device=abs.device)

    # The two strided views select the real and imaginary slots of the buffer.
    polar_kernel(abs, angle, out0=output[..., 0], out1=output[..., 1])

    return torch.view_as_complex(output)