Coverage for src/flag_gems/runtime/backend/_metax/ops/polar.py: 0%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry, libtuner
10logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@libtuner(configs=runtime.get_tuned_config("polar"), key=["n_input"])
@triton.jit
def polar_kernel_kernel(
    abs,
    angle,
    output,
    n_input: tl.constexpr,
    n_output: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Convert one block of polar coordinates into interleaved (real, imag) values.

    Each program loads up to ``BLOCK_SIZE`` elements from ``abs`` and ``angle``
    at the same flat offsets, computes ``abs * cos(angle)`` and
    ``abs * sin(angle)``, interleaves the two results and stores up to
    ``2 * BLOCK_SIZE`` values into ``output``.

    Args:
        abs: Base pointer of the magnitudes, indexed by flat element offset.
        angle: Base pointer of the angles, indexed by the same offsets.
        output: Base pointer of the destination, indexed by flat offset.
        n_input: Number of elements readable from ``abs`` / ``angle``.
        n_output: Number of elements writable in ``output``.
        BLOCK_SIZE: Number of input elements handled per program.
        num_warps: Not referenced in the kernel body.
    """
    block_id = tl.program_id(axis=0)

    # Input side: one lane per (abs, angle) pair, guarded against the tail.
    in_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = in_idx < n_input
    magnitude = tl.load(abs + in_idx, mask=in_bounds)
    phase = tl.load(angle + in_idx, mask=in_bounds)

    real_part = magnitude * tl.cos(phase)
    imag_part = magnitude * tl.sin(phase)

    # Output side: every input lane yields two adjacent values (real, imag),
    # so this program's output window is twice as wide as its input window.
    out_idx = block_id * BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE * 2)
    out_bounds = out_idx < n_output
    packed = tl.interleave(real_part, imag_part)
    tl.store(output + out_idx, packed, mask=out_bounds)
def polar(abs, angle):
    """Build a complex tensor from polar coordinates using the Triton kernel.

    The result holds ``abs * cos(angle)`` as the real part and
    ``abs * sin(angle)`` as the imaginary part of each element.

    Args:
        abs: Tensor of magnitudes. Its shape, dtype and device determine those
            of the intermediate ``(*abs.shape, 2)`` buffer.
        angle: Tensor of angles. It is read at the same flat offsets as
            ``abs``, so it is expected to have the same number of elements
            (not validated here).

    Returns:
        ``torch.view_as_complex`` of the ``(*abs.shape, 2)`` buffer filled by
        the kernel.
    """
    logger.debug("METAX GEMS polar")
    # The kernel addresses both inputs with flat offsets, which is only valid
    # for dense memory; .contiguous() is a no-op for already-dense tensors.
    abs = abs.contiguous()
    angle = angle.contiguous()
    output = torch.empty((*abs.shape, 2), dtype=abs.dtype, device=abs.device)
    n_input = abs.numel()
    n_output = output.numel()

    # Nothing to compute for empty inputs; skip the kernel launch entirely.
    if n_input == 0:
        return torch.view_as_complex(output)

    # Each program consumes BLOCK_SIZE inputs and writes 2 * BLOCK_SIZE
    # outputs, so the grid is sized by the input count. Sizing it by n_output
    # would launch twice as many programs, half of them fully masked.
    grid = lambda meta: (triton.cdiv(n_input, meta["BLOCK_SIZE"]),)
    polar_kernel_kernel[grid](abs, angle, output, n_input, n_output)

    return torch.view_as_complex(output)