Coverage for src/flag_gems/experimental_ops/atanh_.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def atanh_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # In-place elementwise atanh over a flat buffer of n_elements values.
    # Each program instance handles one BLOCK_SIZE-wide slice of the buffer.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    # atanh(v) = 0.5 * ln((1 + v) / (1 - v)); evaluated in float32 for
    # accuracy, then cast back to the buffer's dtype before storing.
    v32 = vals.to(tl.float32)
    result32 = 0.5 * tl.log((1.0 + v32) / (1.0 - v32))
    tl.store(x_ptr + idx, result32.to(vals.dtype), mask=in_bounds)

21 

22 

# Preserve a handle to the Triton kernel before the Python wrapper below
# rebinds the module-level name `atanh_` (the wrapper launches the kernel
# through this alias).
atanh__triton_kernel = atanh_

25 

26 

def atanh_(*args, **kwargs):
    """In-place elementwise atanh on a CUDA floating-point tensor.

    Overwrites the first positional argument with atanh(x) and returns it,
    mirroring ``torch.Tensor.atanh_`` in-place semantics. Extra positional
    and keyword arguments are accepted for signature compatibility but
    are otherwise ignored.

    Raises:
        TypeError: if no argument is given, the first argument is not a
            ``torch.Tensor``, or the tensor is not floating point.
        ValueError: if the tensor is not on a CUDA device.
    """
    if len(args) < 1:
        raise TypeError("atanh_ expects at least one argument: a torch.Tensor")
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("atanh_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("atanh_ expects the tensor to be on a CUDA device")
    if not x.is_floating_point():
        raise TypeError("atanh_ expects a floating-point tensor")

    n_elements = x.numel()
    # Nothing to do for an empty tensor; also avoids launching a
    # zero-size grid, which the kernel launch does not support cleanly.
    if n_elements == 0:
        return x

    # Work on a contiguous buffer so the kernel can index it flat, then copy
    # results back to x to preserve in-place semantics for non-contiguous
    # inputs.
    xc = x.contiguous()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    atanh__triton_kernel[grid](xc, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    # .contiguous() returns x itself when x is already contiguous; in that
    # case the kernel has already written into x and a self-copy is redundant.
    if xc is not x:
        x.copy_(xc)
    return x