Coverage for src/flag_gems/ops/atanh.py: 57%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
9logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def atanh_func(x):
    # Inverse hyperbolic tangent through the logarithmic identity
    #     atanh(x) = 0.5 * ln((1 + x) / (1 - x))
    # The arithmetic is carried out on a float32 copy of the input and the
    # result is cast back to the input's dtype on return.
    wide = x.to(tl.float32)
    ratio = (1.0 + wide) / (1.0 - wide)
    # No explicit domain guard: inputs outside (-1, 1) make the ratio zero,
    # negative or infinite, and the log is left to produce NaN/inf for them.
    # NOTE(review): building the ratio before taking the log may lose relative
    # precision for |x| close to 0 - confirm this meets accuracy requirements.
    half_log = 0.5 * tl.math.log(ratio)
    return half_log.to(x.dtype)
def atanh(A):
    """Apply ``atanh_func`` to *A* and return whatever it returns.

    A debug record is written to the module logger before the call.
    """
    logger.debug("GEMS ATANH")
    result = atanh_func(A)
    return result
def atanh_(A):
    """Run ``atanh_func`` on *A* with *A* itself as the output, then return *A*.

    A debug record is written to the module logger before the call. The value
    returned by ``atanh_func`` is ignored; the caller gets back the same
    object that was passed in.
    """
    logger.debug("GEMS ATANH_")
    # Passing A as ``out0`` directs the result into the input object.
    _ = atanh_func(A, out0=A)
    return A