Coverage for src/flag_gems/experimental_ops/arctanh_.py: 0%
43 statements
Report generated by coverage.py v7.6.9 at 2026-03-13 10:08 +0800.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def arctanh_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_IN_FP32: tl.constexpr
):
    """Overwrite each element of a flat buffer with its inverse hyperbolic tangent.

    Each program instance handles one BLOCK_SIZE-wide chunk of ``x_ptr``;
    lanes whose offset is >= ``n_elements`` are masked off.

    Args:
        x_ptr: pointer to the buffer that is read and overwritten in place.
        n_elements: number of valid elements in the buffer.
        BLOCK_SIZE: number of elements processed per program instance.
        COMPUTE_IN_FP32: when true, the math is done in float32 and the result
            is cast back to the element dtype of ``x_ptr``.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    if COMPUTE_IN_FP32:
        xf = x.to(tl.float32)
    else:
        xf = x

    # Direct form atanh(x) = 0.5 * log((1 + x) / (1 - x)). It is accurate for
    # |x| >= 0.5 and gives +inf at 1, -inf at -1 and NaN outside [-1, 1].
    direct = 0.5 * tl.log((1.0 + xf) / (1.0 - xf))

    # Near zero the ratio rounds toward 1 and the direct form loses its
    # relative precision (in fp32 it returns exactly 0 once both 1 + x and
    # 1 - x round to 1). Use atanh(x) = 0.5 * log1p(u) with u = 2x / (1 - x),
    # evaluating log1p through the compensated identity
    # log1p(u) = log(w) * u / (w - 1), where w = fl(1 + u); when w rounds to
    # exactly 1, log1p(u) == u to working precision.
    u = 2.0 * xf / (1.0 - xf)
    w = 1.0 + u
    d = w - 1.0
    log1p_u = tl.where(d == 0.0, u, tl.log(w) * (u / d))
    small = 0.5 * log1p_u

    val = tl.where(tl.abs(xf) < 0.5, small, direct)
    # No-op cast when no upcast happened; otherwise back to the input dtype.
    tl.store(x_ptr + offsets, val.to(x.dtype), mask=mask)


# Keep a handle on the @triton.jit kernel: the name ``arctanh_`` is rebound
# to the Python wrapper defined below, which launches the kernel via this alias.
ARCTANH_KERNEL = arctanh_
def arctanh_(*args, **kwargs):
    """Compute the inverse hyperbolic tangent of a CUDA tensor in place.

    The tensor is taken from the first positional argument when that is a
    Tensor, otherwise from the ``input`` (or ``self``) keyword; all other
    arguments are ignored. The tensor is overwritten with ``atanh(x)`` and
    returned. Non-contiguous tensors are processed through a contiguous
    temporary whose result is copied back into the original tensor.

    Raises:
        TypeError: if no tensor is supplied or it is not floating point.
        ValueError: if the tensor is not on a CUDA device.
    """
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))
    if not isinstance(x, torch.Tensor):
        raise TypeError("arctanh_ expects a single Tensor argument")

    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if not x.is_floating_point():
        raise TypeError("arctanh_ only supports floating point tensors")

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # The kernel indexes a flat buffer, so a strided tensor is computed on a
    # dense copy and the result is written back into x afterwards.
    buf = x if x.is_contiguous() else x.contiguous()
    # Half-precision inputs are upcast to fp32 inside the kernel.
    use_fp32 = x.dtype in (torch.float16, torch.bfloat16)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Triton launches on the current CUDA device; switch to x's device so a
    # tensor living on a non-default GPU is processed on the right device.
    with torch.cuda.device(x.device):
        ARCTANH_KERNEL[grid](
            buf, n_elements, BLOCK_SIZE=1024, COMPUTE_IN_FP32=use_fp32
        )

    if buf is not x:
        x.copy_(buf)
    return x