Coverage for src/flag_gems/experimental_ops/arctanh.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def arctanh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse hyperbolic tangent: ``out[i] = atanh(x[i])``.

    Each program handles one chunk of ``BLOCK_SIZE`` consecutive elements of
    the flat ``x_ptr`` / ``out_ptr`` buffers; lanes at or beyond
    ``n_elements`` are masked off. The math is done in float32 and the result
    is cast back to the dtype of the loaded input.
    """
    pid = tl.program_id(axis=0)
    # Promote to int64 before scaling so offsets cannot wrap around for
    # buffers with more than 2**31 - 1 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    xf = x.to(tl.float32)

    # atanh(x) = 0.5 * log((1 + x) / (1 - x)) = 0.5 * log1p(2x / (1 - x)).
    # The naive 0.5 * (log(1 + x) - log(1 - x)) cancels catastrophically for
    # small |x| (in float32 it returns 0 for x = 1e-8). log1p(u) is evaluated
    # accurately as u * log(w) / (w - 1) with w = 1 + u, which compensates
    # for the rounding error committed when forming w.
    u = 2.0 * xf / (1.0 - xf)
    w = 1.0 + u
    log1p_u = tl.where(w == 1.0, u, tl.log(w) * (u / (w - 1.0)))
    # x == 1 gives u = w = +inf, where the ratio above is inf / inf = nan;
    # log1p(+inf) is +inf. (x == -1 already yields -inf via log(0).)
    log1p_u = tl.where(w == float("inf"), w, log1p_u)
    y = 0.5 * log1p_u

    tl.store(out_ptr + offsets, y.to(x.dtype), mask=mask)

22 

23 

def _launch_arctanh(x: torch.Tensor, out: torch.Tensor):
    """Compute ``atanh(x)`` elementwise into ``out`` with ``arctanh_kernel``.

    Args:
        x: CUDA tensor of dtype float16, bfloat16 or float32.
        out: CUDA tensor with the same shape and dtype as ``x``; it is
            overwritten with the result.

    Returns:
        ``out``.

    Raises:
        AssertionError: if the device, shape or dtype requirements are not
            met. NOTE(review): these checks are skipped under ``python -O``.
    """
    assert x.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert x.shape == out.shape, "Input and output shapes must match"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute or copy.
        return out

    # The kernel indexes both buffers as flat arrays, so input and output
    # must share the same row-major contiguous element order.
    x_contig = x.contiguous()
    if out.is_contiguous():
        out_contig = out
    else:
        # torch.empty_like's default preserve_format would reproduce out's
        # strides (e.g. a transposed layout) and the flat writes would land at
        # the wrong logical positions; force a contiguous scratch buffer.
        out_contig = torch.empty_like(out, memory_format=torch.contiguous_format)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    arctanh_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)

    if out_contig is not out:
        # Scatter the contiguous result back into the caller's strided view.
        out.copy_(out_contig)
    return out

49 

50 

def arctanh(x: torch.Tensor):
    """Return a new tensor holding ``atanh(x)`` elementwise.

    Args:
        x: CUDA tensor of dtype float16, bfloat16 or float32.

    Returns:
        A freshly allocated, row-major contiguous tensor with the same shape
        and dtype as ``x``.
    """
    # Allocate the result contiguous explicitly. torch.empty_like's default
    # preserve_format would mirror the strides of a dense non-contiguous input
    # (e.g. x.t()), so the buffer's flat order would no longer match the
    # row-major order of x.contiguous() that the kernel reads from.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    _launch_arctanh(x, out)
    return out

55 

56 

def arctanh_out(x: torch.Tensor, out: torch.Tensor):
    """Write ``atanh(x)`` elementwise into the caller-supplied ``out``.

    Returns:
        ``out`` itself, after it has been filled.
    """
    # _launch_arctanh validates the arguments and hands back ``out`` on
    # every path, so its result can be returned directly.
    return _launch_arctanh(x, out)