Coverage for src/flag_gems/runtime/backend/_ascend/ops/angle.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module logger, named "flag_gems.runtime._ascend.ops.<module basename>".
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)
try:
    # torch_npu is not referenced anywhere below (hence F401); it is
    # presumably imported for its side effects, e.g. registering the Ascend
    # NPU device with torch -- TODO confirm.
    import torch_npu  # noqa: F401
except ImportError:
    # torch_npu is optional: the previous code picked the same atan2
    # implementation whether or not the import succeeded.
    pass

# atan2 implementation used by the complex-angle kernel below.
atan2 = tl_extra_shim.atan2
# Code-generation settings for the complex-angle kernel. The positional
# arguments follow CodeGenConfig's signature in
# flag_gems.utils.codegen_config_utils; see there for their meaning.
config_ = CodeGenConfig(
    256,
    (40, 1, 1),
    32,
    False,
    # Parse the whole major-version component rather than just the first
    # character, so a multi-digit major version (e.g. "10.0") is not read as 1.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)
@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], config=config_
)
@triton.jit
def angle_func(real, imag):
    """Elementwise angle ``atan2(imag, real)`` of a complex number given as parts.

    float16 parts are upcast to float32 before calling ``atan2``; parts of any
    other dtype are passed through unchanged. The output dtype is chosen by
    ``pointwise_dynamic`` from the first argument
    (``promotion_methods=[(0, "DEFAULT")]``).
    """
    # Upcast half-precision inputs. Presumably this is for accuracy or op
    # support of atan2 on fp16 -- TODO confirm.
    x, y = (
        (real.to(tl.float32), imag.to(tl.float32))
        if real.dtype == tl.float16
        else (real, imag)
    )
    return atan2(y, x)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    """Elementwise angle of a real-valued (float, int or bool) input.

    Follows ``torch.angle`` for real inputs:
    - ``pi`` for negative values;
    - ``0`` for non-negative values, including ``-0.0``;
    - NaN for NaN inputs (NaN propagates).

    Integer inputs are promoted to float by ``INT_TO_FLOAT``.
    """
    zero = 0.0
    pi = math.pi
    # Test with ``<`` so that NaN (for which every comparison is False) does
    # not land in the pi branch; NaN is handled explicitly below.
    result = tl.where(real < zero, pi, zero)
    # NaN is the only value not equal to itself. Forward the input there so
    # NaN propagates; this is never true for integer or bool inputs.
    result = tl.where(real != real, real, result)
    return result
def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the elementwise angle (argument) of ``input_tensor``.

    Complex inputs (complex32, complex64 and complex128) are split into real
    and imaginary parts and evaluated as ``atan2(imag, real)``. All other
    dtypes go through the real-valued kernel: ``pi`` for negatives, ``0``
    otherwise.
    """
    logger.debug("GEMS_ASCEND ANGLE")
    # Dispatch on is_complex() instead of listing dtypes, so complex128 is not
    # mistakenly routed to the real-valued kernel.
    if input_tensor.is_complex():
        return angle_func(input_tensor.real, input_tensor.imag)
    return angle_float_and_int(input_tensor)