Coverage for src/flag_gems/ops/angle.py: 73%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import pointwise_dynamic, tl_extra_shim
10atan2 = tl_extra_shim.atan2
12logger = logging.getLogger(__name__)
# Element-wise phase of a complex value given as separate real/imag parts.
# The result is atan2(imag, real).
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    # Half-precision operands are widened to float32 before calling atan2;
    # every other dtype is handed to atan2 unchanged. Only the dtype of
    # `real` is inspected, and `imag` follows the same decision.
    if real.dtype == tl.float16:
        x = real.to(tl.float32)
        y = imag.to(tl.float32)
    else:
        x = real
        y = imag
    return atan2(y, x)
# Element-wise angle of a real-valued (floating point or integer) input:
# 0 for values >= 0, pi for values < 0, and NaN for NaN inputs.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    zero = 0.0
    pi = math.pi
    # `real >= zero` is False for negative values (-> pi) and True for
    # zero / positive values (-> 0).
    sign_angle = tl.where(real >= zero, zero, pi)
    # NaN compares False against everything, so without this step a NaN
    # input would silently be reported as pi. NaN is the only value that is
    # not equal to itself; pass it through so the result stays NaN, matching
    # torch.angle's documented NaN propagation. For integer inputs the
    # comparison is always False and the result is unchanged.
    is_nan = real != real
    result = tl.where(is_nan, real.to(tl.float32), sign_angle)
    return result
def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle of ``input_tensor``.

    Complex inputs are split into their real and imaginary components and
    dispatched to ``angle_func``; every other dtype is dispatched to
    ``angle_float_and_int``.

    Args:
        input_tensor: Tensor whose element-wise angle is requested.

    Returns:
        The tensor produced by the selected kernel.
    """
    logger.debug("GEMS ANGLE")
    # Use is_complex() rather than enumerating dtypes so that complex128 is
    # also routed to the atan2 path instead of falling through to the
    # real-valued kernel, which compares its input against 0.
    # NOTE(review): complex128 yields float64 components; confirm the atan2
    # shim supports float64 on every target backend.
    if input_tensor.is_complex():
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)