Coverage for src/flag_gems/experimental_ops/arctanh.py: 0%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def arctanh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse hyperbolic tangent over a flat buffer.

    Each program instance covers one BLOCK_SIZE-wide slice of the input.
    Arithmetic is performed in float32 and cast back to the source dtype
    before the store.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    val_f32 = val.to(tl.float32)

    # atanh(x) = 0.5 * (log(1 + x) - log(1 - x))
    result_f32 = 0.5 * (tl.log(1.0 + val_f32) - tl.log(1.0 - val_f32))
    result = result_f32.to(val.dtype)

    tl.store(out_ptr + idx, result, mask=in_bounds)
def _launch_arctanh(x: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch ``arctanh_kernel``, writing into *out*.

    Both tensors must be CUDA tensors of identical shape and dtype
    (float16, bfloat16, or float32).  Non-contiguous inputs are staged
    through contiguous scratch buffers.  Returns *out*.
    """
    assert x.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert x.shape == out.shape, "Input and output shapes must match"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"

    x_contig = x.contiguous()
    # BUG FIX: the kernel indexes the output as a flat contiguous buffer, but
    # torch.empty_like defaults to memory_format=torch.preserve_format and can
    # return a NON-contiguous scratch tensor when *out* is non-contiguous
    # (e.g. a transposed view), corrupting the element layout before the
    # copy-back.  Force a contiguous layout for the staging buffer.
    out_contig = (
        out
        if out.is_contiguous()
        else torch.empty_like(out, memory_format=torch.contiguous_format)
    )

    n_elements = x_contig.numel()
    if n_elements == 0:
        # Nothing to launch; still honor the copy-back contract (a no-op copy).
        if out_contig is not out:
            out.copy_(out_contig)
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    arctanh_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)

    # If we computed into a scratch buffer, scatter the result back into the
    # caller's (possibly non-contiguous) output tensor.
    if out_contig is not out:
        out.copy_(out_contig)
    return out
def arctanh(x: torch.Tensor):
    """Return a new tensor holding the elementwise inverse hyperbolic tangent of *x*."""
    result = torch.empty_like(x)
    _launch_arctanh(x, result)
    return result
def arctanh_out(x: torch.Tensor, out: torch.Tensor):
    """Compute the elementwise inverse hyperbolic tangent of *x* into the
    preallocated tensor *out* and return it."""
    # _launch_arctanh already returns *out* after the kernel completes.
    return _launch_arctanh(x, out)