Coverage for src/flag_gems/experimental_ops/arccosh.py: 0%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def arccosh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse hyperbolic cosine over a 1-D flattened buffer.

    Each program instance handles one BLOCK_SIZE-wide slice of the input.
    Math is done in float32 regardless of the input dtype.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds).to(tl.float32)

    # acosh(x) = log(x + sqrt(x - 1) * sqrt(x + 1))
    result = tl.log(val + tl.sqrt(val - 1.0) * tl.sqrt(val + 1.0))

    tl.store(out_ptr + idx, result, mask=in_bounds)
def arccosh(input: torch.Tensor) -> torch.Tensor:
    """Compute the elementwise inverse hyperbolic cosine of a CUDA tensor.

    Args:
        input: CUDA tensor of dtype float16, bfloat16, or float32.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: if ``input`` is not on a CUDA device.
        TypeError: if ``input`` has an unsupported dtype.
    """
    # Validate with real exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not input.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if input.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError("Supported dtypes: float16, bfloat16, float32")

    x = input.contiguous()
    out = torch.empty_like(x)

    n_elements = x.numel()
    # Empty tensor: nothing to compute; avoid launching a zero-sized grid.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    arccosh_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out
def arccosh_out(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Compute elementwise acosh of ``input``, writing results into ``out``.

    Args:
        input: CUDA tensor of dtype float16, bfloat16, or float32.
        out: CUDA tensor with the same shape and dtype as ``input``.

    Returns:
        ``out``, filled with acosh(input).

    Raises:
        ValueError: if the tensors are not on CUDA, or their shapes or
            dtypes differ.
        TypeError: if the dtype is unsupported.
    """
    # Validate with real exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not (input.is_cuda and out.is_cuda):
        raise ValueError("Tensors must be on CUDA device")
    if input.shape != out.shape:
        raise ValueError("Input and out must have the same shape")
    if input.dtype != out.dtype:
        raise ValueError("Input and out must have the same dtype")
    if input.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError("Supported dtypes: float16, bfloat16, float32")

    x = input.contiguous()
    n_elements = x.numel()
    # Empty tensor: nothing to compute; avoid launching a zero-sized grid.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    if out.is_contiguous():
        arccosh_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    else:
        # The kernel writes a dense layout; go through a contiguous
        # scratch buffer and copy back to preserve `out`'s strides.
        tmp = torch.empty_like(x)
        arccosh_kernel[grid](x, tmp, n_elements, BLOCK_SIZE=1024)
        out.copy_(tmp)
    return out