Coverage for src/flag_gems/ops/polar.py: 82%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module-level logger named after this module; polar() emits its debug message through it.
logger = logging.getLogger(__name__)

10 

11 

@pointwise_dynamic(
    promotion_methods=[
        ((0, 1), "DEFAULT"),
        ((0, 1), "DEFAULT"),
    ],
    num_outputs=2,
)
@triton.jit
def polar_kernel(magnitude, phase):
    # Element-wise polar -> Cartesian conversion.
    # Inputs (positional): magnitude r, phase theta.
    # Outputs: (r * cos(theta), r * sin(theta)); both outputs are typed by
    # DEFAULT promotion over the two inputs (see promotion_methods above).
    cos_phase = tl.cos(phase)
    sin_phase = tl.sin(phase)
    return magnitude * cos_phase, magnitude * sin_phase

24 

25 

def polar(abs, angle):
    """Build a complex tensor from magnitude and phase, mirroring ``torch.polar``.

    Computes ``abs * cos(angle) + 1j * abs * sin(angle)`` element-wise.

    Args:
        abs: Magnitude tensor. Its dtype and device are used for the real
            buffer backing the complex result. (The parameter name matches
            ``torch.polar`` and is kept for keyword-call compatibility.)
        angle: Phase tensor, in radians; broadcast against ``abs``.

    Returns:
        A complex tensor obtained with ``torch.view_as_complex`` from a
        ``(*broadcast_shape, 2)`` buffer of dtype ``abs.dtype``.

    Raises:
        RuntimeError: if ``abs.shape`` and ``angle.shape`` are not
            broadcast-compatible (raised by ``torch.broadcast_shapes``).
    """
    logger.debug("GEMS POLAR")

    # torch.polar broadcasts its inputs, so size the output from the broadcast
    # shape instead of abs.shape alone; otherwise an ``angle`` with a larger
    # broadcast shape could not be written into the preallocated outputs.
    out_shape = torch.broadcast_shapes(abs.shape, angle.shape)

    # Trailing dim of size 2 holds interleaved [real, imag] pairs in a fresh
    # contiguous buffer, which is the layout torch.view_as_complex requires.
    output = torch.empty((*out_shape, 2), dtype=abs.dtype, device=abs.device)

    # The kernel writes real parts into out0 and imaginary parts into out1,
    # both strided views into ``output``.
    polar_kernel(abs, angle, out0=output[..., 0], out1=output[..., 1])

    return torch.view_as_complex(output)