Coverage for src/flag_gems/ops/arctanh_.py: 48%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
# Evaluate atanh(x) elementwise on a Triton block, in the dtype of ``x``.
#
# For |x| >= 1/16 this is the closed form 0.5 * log((1 + x) / (1 - x)), which
# also yields +/-inf at x = +/-1 and NaN for |x| > 1 or NaN input.
#
# The closed form loses relative precision as |x| -> 0 because 1 + x rounds
# toward 1. For example, in fp32, x = 1e-8 would produce exactly 0. It also
# maps -0.0 to +0.0. For |x| < 1/16 the Maclaurin series
# x + x^3/3 + ... + x^13/13 is used instead: it preserves the sign of zero, and
# its truncation error there is about (1/16)**14 / 15 ~= 1e-18 relative, below
# fp64 epsilon.
@triton.jit
def _arctanh_compute(x):
    x2 = x * x
    p3 = x * x2
    p5 = p3 * x2
    p7 = p5 * x2
    p9 = p7 * x2
    p11 = p9 * x2
    p13 = p11 * x2
    # Integer divisors are cast to x's float dtype, so the fp64 path keeps
    # exact coefficients. A Python float literal such as 1/3 may instead be
    # materialized as an fp32 constant. Summing smallest terms first limits
    # rounding error.
    tail = p13 / 13 + p11 / 11 + p9 / 9 + p7 / 7 + p5 / 5 + p3 / 3
    series = x + tail
    closed = 0.5 * tl.log((1 + x) / (1 - x))
    # tl.where evaluates both branches; any inf/NaN in the unused one is discarded.
    return tl.where(tl.abs(x) < 0.0625, series, closed)


# In-place atanh over the first ``n_elements`` entries of the flat buffer at
# ``x_ptr``. Each program handles one BLOCK_SIZE-wide chunk, and lanes past
# ``n_elements`` are masked off. When COMPUTE_IN_FP32 is set, values are
# upcast to fp32 for the math and rounded back to the input dtype once on store.
@triton.jit
def arctanh_kernel_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_IN_FP32: tl.constexpr
):
    # Widen before multiplying so offsets cannot overflow int32 on tensors
    # with >= 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    if COMPUTE_IN_FP32:
        y = _arctanh_compute(x.to(tl.float32)).to(x.dtype)
    else:
        y = _arctanh_compute(x)

    tl.store(x_ptr + offsets, y, mask=mask)
def _launch_arctanh_inplace(t):
    """Run ``arctanh_kernel_`` in place over the contiguous floating tensor *t*."""
    n_elements = t.numel()
    # fp16/bf16 are upcast to fp32 inside the kernel for the log/divide.
    use_fp32 = t.dtype in (torch.float16, torch.bfloat16)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(t.device):
        arctanh_kernel_[grid](t, n_elements, BLOCK_SIZE=1024, COMPUTE_IN_FP32=use_fp32)


def arctanh_(*args, **kwargs):
    """Replace every element of a floating-point tensor with its atanh, in place.

    The tensor is taken from the first positional argument when that is a
    ``torch.Tensor``. Otherwise it is taken from the ``input`` keyword, or
    failing that the ``self`` keyword.

    Non-contiguous tensors are supported: the kernel runs on a contiguous
    copy, and the result is copied back into the original storage.

    Returns:
        The same tensor object that was passed in, now holding the results.

    Raises:
        TypeError: if no tensor argument is found, or the tensor is not of a
            floating-point dtype.
    """
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))
    if not isinstance(x, torch.Tensor):
        raise TypeError("arctanh_ expects a single Tensor argument")
    if not x.is_floating_point():
        raise TypeError("arctanh_ only supports floating point tensors")

    if x.numel() == 0:
        return x

    if x.is_contiguous():
        _launch_arctanh_inplace(x)
    else:
        # The kernel indexes a flat contiguous buffer, so compute on a dense
        # copy and write the results back through x's own strides.
        buf = x.contiguous()
        _launch_arctanh_inplace(buf)
        x.copy_(buf)
    return x