Coverage for src/flag_gems/ops/arctanh_.py: 48%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def arctanh_kernel_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_IN_FP32: tl.constexpr
):
    """In-place elementwise atanh over a flat buffer of ``n_elements`` values.

    Each program instance handles one BLOCK_SIZE-wide tile. The result is
    computed as 0.5 * log((1 + x) / (1 - x)) and written back to ``x_ptr``.
    When COMPUTE_IN_FP32 is set, the math is done in float32 and the result
    is cast back to the input dtype (used for fp16/bf16 inputs).
    """
    tile_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Masked lanes load 0.0; their results are never stored.
    v = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)

    if COMPUTE_IN_FP32:
        # Upcast for accuracy, then cast back to the storage dtype.
        w = v.to(tl.float32)
        out = (0.5 * tl.log((1.0 + w) / (1.0 - w))).to(v.dtype)
    else:
        # Native-precision path (fp32/fp64 inputs).
        out = 0.5 * tl.log((1 + v) / (1 - v))

    tl.store(x_ptr + idx, out, mask=in_bounds)

34 

35 

def arctanh_(*args, **kwargs):
    """In-place arctanh: overwrites the input tensor with atanh(x) and returns it.

    The tensor may be passed positionally or via the ``input`` / ``self``
    keyword. The tensor must be contiguous and of a floating-point dtype;
    fp16/bf16 inputs are computed in float32 inside the kernel for accuracy.

    Raises:
        TypeError: if no Tensor argument is found, or the dtype is not floating.
        ValueError: if the tensor is not contiguous.
    """
    # Locate the tensor argument: first positional wins, else keyword lookup.
    if args and isinstance(args[0], torch.Tensor):
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("self", None))
        if not isinstance(tensor, torch.Tensor):
            raise TypeError("arctanh_ expects a single Tensor argument")

    # Validate layout and dtype before touching the device.
    if not tensor.is_contiguous():
        raise ValueError("Input tensor must be contiguous")
    if not tensor.is_floating_point():
        raise TypeError("arctanh_ only supports floating point tensors")

    numel = tensor.numel()
    if numel == 0:
        # Nothing to do; in-place ops return their input.
        return tensor

    # Reduced-precision dtypes go through the fp32 compute path.
    upcast = tensor.dtype in (torch.float16, torch.bfloat16)

    block = 1024
    launch_grid = (triton.cdiv(numel, block),)
    with torch_device_fn.device(tensor.device):
        arctanh_kernel_[launch_grid](
            tensor, numel, BLOCK_SIZE=block, COMPUTE_IN_FP32=upcast
        )
    return tensor