Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/angle.py: 0%
28 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import math
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
atan2 = tl_extra_shim.atan2


@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    """Element-wise ``atan2(imag, real)`` over the parts of a complex tensor.

    When ``real`` is float16, both components are widened to float32 before
    calling ``atan2``. All other dtypes are passed through unchanged. The
    output dtype is chosen by ``pointwise_dynamic`` from argument 0
    (``real``) using the "DEFAULT" promotion method.
    """
    if real.dtype == tl.float16:
        # Only the dtype of `real` is inspected; both parts are widened
        # together so atan2 sees matching operand types.
        real = real.to(tl.float32)
        imag = imag.to(tl.float32)
    return atan2(imag, real)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    """Element-wise angle of a real-valued (floating-point or integer) input.

    Follows ``torch.angle`` semantics for real tensors:
    ``pi`` for negative values, ``0`` for non-negative values (including
    ``-0.0``), and NaN inputs are propagated unchanged. Integer inputs are
    promoted to a floating output dtype by ``pointwise_dynamic``
    ("INT_TO_FLOAT").
    """
    zero = 0.0
    pi = math.pi
    # `real < 0` (rather than `not real >= 0`) is False for both -0.0 and
    # NaN, so neither is mapped to pi here.
    result = tl.where(real < zero, pi, zero)
    # NaN is the only value that compares unequal to itself. Integer inputs
    # never take this branch. Previously NaN was returned as pi because
    # `NaN >= 0` is False.
    result = tl.where(real != real, real, result)
    return result
def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle (argument) of ``input_tensor``.

    Complex inputs are split into their real and imaginary parts and
    evaluated as ``atan2(imag, real)`` by ``angle_func``. Real inputs
    (floating-point or integer) are dispatched to ``angle_float_and_int``,
    which returns ``pi`` for negative values and ``0`` for non-negative ones.

    Args:
        input_tensor: Complex, floating-point, or integer tensor.

    Returns:
        A real-valued tensor of element-wise angles.
    """
    # is_complex() covers every complex dtype. The previous explicit
    # complex32/complex64 check sent complex128 to the real-valued kernel.
    if input_tensor.is_complex():
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)