Coverage for src/flag_gems/runtime/backend/_ascend/ops/polar.py: 0%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.codegen_config_utils import CodeGenConfig 

9 

10logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

11 

12 

13config_ = CodeGenConfig( 

14 384, 

15 tuple([48, 1, 1]), 

16 32, 

17 False, 

18 prefer_1d_tile=int(triton.__version__[0]) < 3, 

19) 

20 

21 

22@pointwise_dynamic( 

23 promotion_methods=[ 

24 ((0, 1), "DEFAULT"), 

25 ((0, 1), "DEFAULT"), 

26 ], 

27 num_outputs=2, 

28 config=config_, 

29) 

30@triton.jit 

31def polar_kernel(abs, angle): 

32 real = abs * tl.cos(angle) 

33 imag = abs * tl.sin(angle) 

34 return real, imag 

35 

36 

37def polar(abs, angle): 

38 logger.debug("GEMS_ASCEND POLAR") 

39 output = torch.empty((*abs.shape, 2), dtype=abs.dtype, device=abs.device) 

40 

41 polar_kernel(abs, angle, out0=output[..., 0], out1=output[..., 1]) 

42 

43 return torch.view_as_complex(output)