Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/angle.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# NOTE(review): atan2 is taken from flag_gems' tl_extra_shim rather than from a
# specific device math library, presumably so each backend can supply its own
# implementation — confirm against tl_extra_shim.
atan2 = tl_extra_shim.atan2

12 

13 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    # Argument of the complex value real + i*imag, i.e. atan2(imag, real).
    # Output dtype follows the first operand under the "DEFAULT" promotion
    # rule declared in the decorator above.
    #
    # Half-precision components are widened to float32 before atan2 is
    # evaluated; any other dtype is passed through unchanged. The dtype test
    # is resolved when the kernel is compiled, so only one path is emitted.
    if real.dtype == tl.float16:
        real = real.to(tl.float32)
        imag = imag.to(tl.float32)
    return atan2(imag, real)

24 

25 

@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    # Angle of a real-valued (floating point or integer) element, matching
    # torch.angle for real inputs:
    #   pi  where x < 0,
    #   0   where x >= 0 (including -0.0),
    #   NaN where x is NaN.
    # Integer inputs are promoted to a float output per "INT_TO_FLOAT".
    zero = 0.0
    pi = math.pi
    nan = math.nan
    result = tl.where(real < zero, pi, zero)
    # NaN compares false against everything. Without this select it would land
    # in one of the branches above (the old `real >= 0` test produced pi).
    # Propagate it instead. For integer inputs `real != real` is never true,
    # so this is a no-op there.
    result = tl.where(real != real, nan, result)
    return result

34 

35 

def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle (argument) of ``input_tensor``.

    Complex inputs of any precision (complex32, complex64, complex128) are
    split into their real and imaginary parts and evaluated as
    ``atan2(imag, real)`` by ``angle_func``. All other inputs (floating point
    or integer) are treated as real numbers and evaluated by
    ``angle_float_and_int``.

    Args:
        input_tensor: Tensor whose angle is computed.

    Returns:
        A real-valued tensor of angles. Its dtype is decided by the
        ``pointwise_dynamic`` promotion rule of the kernel that handles it.
    """
    # is_complex() covers every complex dtype. Checking only complex32 and
    # complex64 would send complex128 to the real-only kernel below.
    if input_tensor.is_complex():
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)