Coverage for src/flag_gems/experimental_ops/arctanh_.py: 0% — 43 statements
(coverage.py v7.6.9, report created 2026-03-11 02:28 +0800)

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def arctanh_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_IN_FP32: tl.constexpr
):
    """Overwrite a flat buffer of ``n_elements`` values with their arctanh.

    Each program instance processes one BLOCK_SIZE-wide tile of the buffer.
    When COMPUTE_IN_FP32 is set, the math is carried out in float32 and the
    result is cast back to the buffer's dtype (intended for low-precision
    inputs such as fp16/bf16).
    """
    tile_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    v = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)

    # arctanh(x) = 0.5 * log((1 + x) / (1 - x))
    if COMPUTE_IN_FP32:
        w = v.to(tl.float32)
        out = (0.5 * tl.log((1.0 + w) / (1.0 - w))).to(v.dtype)
    else:
        out = 0.5 * tl.log((1 + v) / (1 - v))

    tl.store(x_ptr + idx, out, mask=in_bounds)

31 

32 

# Keep a handle to the Triton kernel before the Python wrapper below reuses
# the name ``arctanh_`` and shadows it at module level.
ARCTANH_KERNEL = arctanh_

34 

35 

def arctanh_(*args, **kwargs):
    """In-place arctanh: overwrite the input tensor with arctanh(x).

    The tensor may be passed positionally or via the ``input`` / ``self``
    keyword. It must be a contiguous floating-point CUDA tensor; the kernel
    writes results back into it and the same tensor is returned.

    Raises:
        TypeError: no Tensor argument was supplied, or the tensor is not a
            floating-point dtype.
        ValueError: the tensor is not on a CUDA device, or not contiguous.
    """
    # Locate the tensor argument: positional first, then keyword aliases.
    if args and isinstance(args[0], torch.Tensor):
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("self", None))
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("arctanh_ expects a single Tensor argument")

    if not tensor.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if not tensor.is_contiguous():
        raise ValueError("Input tensor must be contiguous")
    if not tensor.is_floating_point():
        raise TypeError("arctanh_ only supports floating point tensors")

    total = tensor.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return tensor

    # Low-precision inputs are promoted to fp32 inside the kernel for accuracy.
    promote = tensor.dtype in (torch.float16, torch.bfloat16)

    # One program per 1024-element tile.
    launch_grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    ARCTANH_KERNEL[launch_grid](
        tensor, total, BLOCK_SIZE=1024, COMPUTE_IN_FP32=promote
    )
    return tensor