Coverage for src/flag_gems/experimental_ops/arccosh.py: 0%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def arccosh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse hyperbolic cosine over a flat buffer.

    Each program instance processes one contiguous chunk of BLOCK_SIZE
    elements; a mask guards the final partial chunk.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    # Promote to float32 so the math is done at full precision even for
    # float16 / bfloat16 inputs; tl.store casts back to the output dtype.
    val = tl.load(x_ptr + offs, mask=in_bounds).to(tl.float32)

    # acosh(x) = log(x + sqrt(x - 1) * sqrt(x + 1))
    result = tl.log(val + tl.sqrt(val - 1.0) * tl.sqrt(val + 1.0))

    tl.store(out_ptr + offs, result, mask=in_bounds)

22 

23 

def arccosh(input: torch.Tensor):
    """Return a new tensor holding the elementwise acosh of *input*.

    Accepts a CUDA tensor of dtype float16, bfloat16, or float32; the
    result has the same shape and dtype as the input.
    """
    assert input.is_cuda, "Input tensor must be on CUDA device"
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"

    # Work on a contiguous copy so the kernel can index the buffer flatly.
    src = input.contiguous()
    result = torch.empty_like(src)
    total = src.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    arccosh_kernel[grid](src, result, total, BLOCK_SIZE=1024)
    return result

40 

41 

def arccosh_out(input: torch.Tensor, out: torch.Tensor):
    """Compute acosh(input) elementwise, writing the result into *out*.

    Both tensors must be CUDA tensors with matching shape and dtype.
    Returns *out*.
    """
    assert input.is_cuda and out.is_cuda, "Tensors must be on CUDA device"
    assert input.shape == out.shape, "Input and out must have the same shape"
    assert input.dtype == out.dtype, "Input and out must have the same dtype"
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"

    src = input.contiguous()
    total = src.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    if out.is_contiguous():
        # The kernel can write straight into the caller's buffer.
        arccosh_kernel[grid](src, out, total, BLOCK_SIZE=1024)
    else:
        # Non-contiguous destination: compute into a scratch tensor and
        # let copy_ handle the strided writeback.
        scratch = torch.empty_like(src)
        arccosh_kernel[grid](src, scratch, total, BLOCK_SIZE=1024)
        out.copy_(scratch)
    return out