Coverage for src/flag_gems/runtime/backend/_metax/ops/polar.py: 0%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry, libtuner 

9 

10logger = logging.getLogger("flag_gems." + __name__) 

11 

12 

@libentry()
@libtuner(configs=runtime.get_tuned_config("polar"), key=["n_input"])
@triton.jit
def polar_kernel_kernel(
    abs,
    angle,
    output,
    n_input: tl.constexpr,
    n_output: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Convert polar coordinates to interleaved Cartesian components.

    Each program reads up to ``BLOCK_SIZE`` elements from ``abs`` and
    ``angle`` at the same flat offsets, computes ``abs * cos(angle)`` and
    ``abs * sin(angle)``, and writes the two results interleaved into
    ``output``, covering up to ``2 * BLOCK_SIZE`` output slots.

    Args:
        abs: Base pointer of the magnitude values (flat addressing).
        angle: Base pointer of the angle values (flat addressing).
        output: Base pointer of the interleaved result buffer.
        n_input: Upper bound (exclusive) for valid input offsets.
        n_output: Upper bound (exclusive) for valid output offsets.
        BLOCK_SIZE: Number of input elements handled per program.
        num_warps: Not referenced in the kernel body; presumably supplied
            by the tuned configuration -- confirm against the config source.
    """
    block_id = tl.program_id(axis=0)

    # Input side: one contiguous window of BLOCK_SIZE elements per program.
    in_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = in_idx < n_input
    magnitude = tl.load(abs + in_idx, mask=in_bounds)
    phase = tl.load(angle + in_idx, mask=in_bounds)

    # Cartesian components of magnitude * exp(i * phase).
    real_part = magnitude * tl.cos(phase)
    imag_part = magnitude * tl.sin(phase)

    # Pair every real component with its imaginary component so the buffer
    # reads (re, im, re, im, ...); the output window is twice as wide as the
    # input window.
    packed = tl.interleave(real_part, imag_part)
    out_idx = block_id * BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE * 2)
    tl.store(output + out_idx, packed, mask=out_idx < n_output)

39 

40 

def polar(abs, angle):
    """Build a complex tensor from magnitudes ``abs`` and phases ``angle``.

    Allocates a real buffer with a trailing dimension of size 2, lets
    ``polar_kernel_kernel`` fill it with interleaved
    ``(abs * cos(angle), abs * sin(angle))`` pairs, and returns that buffer
    reinterpreted as a complex tensor.

    Args:
        abs: Tensor of magnitudes. Its shape, dtype and device determine
            those of the intermediate real buffer.
        angle: Tensor of angles, read at the same flat offsets as ``abs``.

    Returns:
        The result of ``torch.view_as_complex`` on the filled buffer, i.e. a
        complex tensor with the shape of ``abs``.

    NOTE(review): ``angle`` is assumed to hold at least as many elements as
    ``abs`` and to share its shape; no broadcasting or shape check is done
    here -- confirm that callers guarantee this.
    """
    logger.debug("METAX GEMS polar")

    # The kernel addresses both inputs with flat offsets, so strided or
    # transposed inputs must be made dense first. ``contiguous()`` returns
    # the tensor itself when it is already contiguous, so the common case
    # pays no copy.
    abs_dense = abs.contiguous()
    angle_dense = angle.contiguous()

    output = torch.empty(
        (*abs_dense.shape, 2), dtype=abs_dense.dtype, device=abs_dense.device
    )
    n_input = abs_dense.numel()
    n_output = output.numel()

    # Nothing to compute: skip the launch rather than request an empty grid.
    if n_input == 0:
        return torch.view_as_complex(output)

    def grid(meta):
        # One program consumes BLOCK_SIZE inputs and writes 2 * BLOCK_SIZE
        # outputs, so the grid is sized by the input count. Sizing it by
        # n_output would launch twice as many programs, half of them fully
        # masked.
        return (triton.cdiv(n_input, meta["BLOCK_SIZE"]),)

    polar_kernel_kernel[grid](abs_dense, angle_dense, output, n_input, n_output)

    return torch.view_as_complex(output)