Coverage for src/flag_gems/ops/angle.py: 73%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

9 

10atan2 = tl_extra_shim.atan2 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    """Return ``atan2(imag, real)`` for the two components of a complex value.

    Both arguments are declared as tensors to ``pointwise_dynamic`` and the
    result dtype follows the "DEFAULT" promotion method keyed on argument 0.

    Args:
        real: real component.
        imag: imaginary component.

    Returns:
        The value produced by the ``atan2`` shim for ``(imag, real)``.
    """
    # float16 operands are widened to float32 before calling atan2; any other
    # dtype is passed through untouched. Both operands are gated on the dtype
    # of `real`, exactly as a single shared check would do.
    # NOTE(review): presumably the atan2 shim has no half-precision variant
    # -- confirm.
    x = real.to(tl.float32) if real.dtype == tl.float16 else real
    y = imag.to(tl.float32) if real.dtype == tl.float16 else imag
    return atan2(y, x)

25 

26 

@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    """Return the angle of a real-valued input.

    The result is 0 where ``real >= 0``, pi where ``real < 0``, and NaN where
    ``real`` is NaN. The single argument is declared as a tensor to
    ``pointwise_dynamic`` and the result dtype follows the "INT_TO_FLOAT"
    promotion method keyed on argument 0.

    Args:
        real: real-valued input (floating point or integer).

    Returns:
        0, pi or NaN as described above.
    """
    zero = 0.0
    pi = math.pi
    result = tl.where(real >= zero, zero, pi)
    # A NaN fails `real >= zero`, so the select above maps it to pi.
    # torch.angle is documented to propagate NaNs for real inputs, so restore
    # the NaN here. `real != real` is true only for NaN and is never true for
    # integer inputs, which therefore keep the 0 / pi result unchanged.
    result = tl.where(real != real, real.to(tl.float32), result)
    return result

35 

36 

def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle of ``input_tensor``.

    complex32 and complex64 inputs are split into their real and imaginary
    parts and handed to ``angle_func``; every other dtype is passed as-is to
    ``angle_float_and_int``.

    Args:
        input_tensor: tensor whose angle is requested.

    Returns:
        The tensor returned by the selected kernel.
    """
    logger.debug("GEMS ANGLE")
    # NOTE(review): complex128 is not listed here, so it takes the real-valued
    # path below -- confirm whether that dtype is expected to be supported.
    if input_tensor.dtype in (torch.complex32, torch.complex64):
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)