Coverage for src/flag_gems/runtime/backend/_ascend/ops/angle.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

9from flag_gems.utils.codegen_config_utils import CodeGenConfig 

10 

logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


# Best-effort import: torch_npu is imported only for its import-time side
# effects (the name itself is never used in this module). If it is not
# installed, the module still loads.
# NOTE(review): presumably this registers the Ascend NPU device with torch — confirm.
try:
    import torch_npu  # noqa: F401
except ImportError:
    pass

# Both former branches of the try/except bound this same shim, so bind it once.
atan2 = tl_extra_shim.atan2

# Code-generation config passed to `angle_func`'s `pointwise_dynamic` decorator.
# The positional values are forwarded unchanged to CodeGenConfig; see its
# definition for their meaning. The Triton *major* version is compared
# numerically. Indexing the version string (`triton.__version__[0]`) reads only
# its first character and would misparse a two-digit major such as "10.0.0".
config_ = CodeGenConfig(
    256,
    (40, 1, 1),
    32,
    False,
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

28 

29 

@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], config=config_
)
@triton.jit
def angle_func(real, imag):
    """Element-wise ``atan2(imag, real)``: the argument of ``real + i * imag``.

    When the operands are float16 they are upcast to float32 before calling
    ``atan2``; any other dtype is passed through unchanged. The output dtype is
    governed by the decorator's ``promotion_methods`` (``DEFAULT`` on operand 0).
    """
    x = real
    y = imag
    # `real.dtype == tl.float16` is a compile-time comparison, so this branch
    # is resolved statically when the kernel is specialized.
    if real.dtype == tl.float16:
        x = real.to(tl.float32)
        y = imag.to(tl.float32)
    return atan2(y, x)

42 

43 

@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    """Element-wise angle of a real-valued (floating-point or integer) input.

    Follows ``torch.angle`` semantics for real inputs:
    - ``pi`` for negative values;
    - ``0`` for non-negative values (``-0.0`` included, since ``-0.0 < 0`` is false);
    - NaN inputs propagate as NaN.

    Integer inputs yield a floating-point result via the ``INT_TO_FLOAT``
    promotion.
    """
    zero = 0.0
    pi = math.pi
    result = tl.where(real < zero, pi, zero)
    # NaN compares false against everything, so the line above maps NaN to 0.
    # `real != real` is true only for NaN (never for integers), so those lanes
    # keep their NaN value instead.
    result = tl.where(real != real, real, result)
    return result

52 

53 

def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle (argument), in radians, of ``input_tensor``.

    Complex inputs are split into their real and imaginary parts and dispatched
    to ``angle_func`` (``atan2(imag, real)``). All other inputs are treated as
    real values and dispatched to ``angle_float_and_int``.

    Args:
        input_tensor: complex, floating-point, or integer tensor.

    Returns:
        A new tensor holding the angle of each element.
    """
    logger.debug("GEMS_ASCEND ANGLE")
    # is_complex() covers every complex dtype (complex32/64/128), so complex128
    # reaches the atan2 kernel rather than the real-valued one.
    # NOTE(review): complex128 produces float64 operands; confirm the backend's
    # atan2 shim supports float64 on Ascend.
    if input_tensor.is_complex():
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)