Coverage for src/flag_gems/experimental_ops/atanh_.py: 0%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def atanh_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer with its element-wise inverse hyperbolic tangent.

    Each program instance handles one ``BLOCK_SIZE``-wide slice of the buffer:
    it loads the slice, evaluates ``0.5 * log((1 + x) / (1 - x))`` in float32,
    casts the result back to the loaded dtype and stores it to the same
    addresses it was read from.

    Args:
        x_ptr: Pointer to the first element; the buffer is read and overwritten.
        n_elements: Number of valid elements starting at ``x_ptr``.
        BLOCK_SIZE: Compile-time number of elements handled per program instance.
    """
    # Widen the program id before scaling: an int32 `pid * BLOCK_SIZE` wraps
    # around once the buffer holds 2**31 or more elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Lanes past the end load 0 and are never written back.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_fp32 = x.to(tl.float32)
    # NOTE(review): the math is always done in float32, so a float64 input is
    # reduced to float32 precision, and the log-ratio form loses relative
    # accuracy for very small |x| -- confirm this is acceptable.
    out_fp32 = 0.5 * tl.log((1.0 + x_fp32) / (1.0 - x_fp32))

    tl.store(x_ptr + offsets, out_fp32.to(x.dtype), mask=mask)


# Keep a handle to the Triton kernel: the Python wrapper defined later in this
# module reuses the name `atanh_` and would otherwise shadow it.
atanh__triton_kernel = atanh_
def atanh_(*args, **kwargs):
    """Apply the inverse hyperbolic tangent to a CUDA tensor in place.

    Args:
        *args: The first positional argument is the tensor to update. Any
            further positional arguments are ignored.
        **kwargs: Accepted but ignored.

    Returns:
        The tensor that was passed in, now holding the result.

    Raises:
        TypeError: If no positional argument is given, if the first argument
            is not a ``torch.Tensor``, or if the tensor is not floating-point.
        ValueError: If the tensor is not on a CUDA device.
    """
    if not args:
        raise TypeError("atanh_ expects at least one argument: a torch.Tensor")
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("atanh_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("atanh_ expects the tensor to be on a CUDA device")
    if not x.is_floating_point():
        raise TypeError("atanh_ expects a floating-point tensor")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return x

    # The kernel indexes a flat buffer, so a non-contiguous input is processed
    # through a contiguous copy whose results are written back afterwards.
    xc = x if x.is_contiguous() else x.contiguous()

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Launch on the tensor's own device; without this the kernel would run on
    # the current CUDA device, which may differ on multi-GPU hosts.
    with torch.cuda.device(x.device):
        atanh__triton_kernel[grid](xc, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if xc is not x:
        # Only needed when a temporary contiguous copy was used.
        x.copy_(xc)
    return x