Coverage for src/flag_gems/experimental_ops/arcsinh.py: 0%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def arcsinh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse hyperbolic sine kernel.

    Each program instance processes one contiguous BLOCK_SIZE-wide slice of
    the flattened input; out-of-range lanes are masked off.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # asinh(x) = sign(x) * log(|x| + sqrt(x*x + 1)).
    # Computing on |x| and restoring the sign avoids catastrophic
    # cancellation for large negative inputs, where the naive form
    # x + sqrt(x*x + 1) collapses toward 0 and log() loses all precision.
    x_f32 = x.to(tl.float32)
    abs_x = tl.abs(x_f32)
    sqrt_term = tl.sqrt(abs_x * abs_x + 1.0)
    sign = tl.where(x_f32 < 0, -1.0, 1.0)
    y_f32 = sign * tl.log(abs_x + sqrt_term)

    # Store result; tl.store casts float32 to the output pointer's dtype as needed.
    tl.store(out_ptr + offsets, y_f32, mask=mask)
def _ensure_cuda_tensor(t):
    """Validate that *t* is a real-valued tensor living on a CUDA device.

    Raises:
        TypeError: *t* is not a ``torch.Tensor``.
        ValueError: *t* is a tensor but not on a CUDA device.
        NotImplementedError: *t* has a complex dtype.
    """
    if isinstance(t, torch.Tensor):
        if t.is_cuda:
            if not t.is_complex():
                return  # valid: real-valued CUDA tensor
            raise NotImplementedError(
                "Complex dtypes are not supported by this Triton kernel"
            )
        raise ValueError("Input tensors must be on CUDA device")
    raise TypeError("Expected a torch.Tensor")
def _arcsinh_impl(input_tensor: torch.Tensor, out_tensor: torch.Tensor = None):
    """Compute elementwise asinh of a CUDA tensor via the Triton kernel.

    Args:
        input_tensor: real-valued CUDA tensor; bool/integer dtypes are
            promoted to float32, floating dtypes keep their own dtype.
        out_tensor: optional preallocated CUDA tensor to write into; must
            have the same element count as the input and the promoted dtype.

    Returns:
        A tensor holding asinh(input_tensor) — ``out_tensor`` when given,
        otherwise a freshly allocated tensor shaped like the input.

    Raises:
        TypeError: non-tensor input, or ``out_tensor`` dtype mismatch.
        ValueError: tensor not on CUDA, or ``out_tensor`` element-count mismatch.
        NotImplementedError: complex input.
    """
    _ensure_cuda_tensor(input_tensor)

    # Basic promotion: floating inputs keep their dtype, everything else
    # (bool/int) computes in float32.
    if input_tensor.is_floating_point():
        result_dtype = input_tensor.dtype
    else:
        result_dtype = torch.float32

    x = input_tensor
    n_elements = x.numel()

    if out_tensor is None:
        out = torch.empty_like(x, dtype=result_dtype, device=x.device)
    else:
        _ensure_cuda_tensor(out_tensor)
        if out_tensor.numel() != n_elements:
            raise ValueError(
                "Output tensor must have the same number of elements as input"
            )
        # Enforce dtype consistent with promotion
        if out_tensor.dtype != result_dtype:
            raise TypeError(
                f"Output tensor has dtype {out_tensor.dtype}, expected {result_dtype}"
            )
        out = out_tensor

    # Nothing to compute for empty tensors; also avoids launching a
    # zero-size grid, which the original code would attempt.
    if n_elements == 0:
        return out

    # Work with contiguous buffers for the kernel
    x_contig = x.contiguous()
    out_contig = out if out.is_contiguous() else out.contiguous()

    # Launch kernel: one program per BLOCK_SIZE elements.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    arcsinh_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    # If out was non-contiguous, the kernel wrote into a temporary copy;
    # copy the results back into the caller's tensor.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)

    return out
def arcsinh(input_tensor: torch.Tensor):
    """Return the elementwise inverse hyperbolic sine of *input_tensor*."""
    result = _arcsinh_impl(input_tensor)
    return result
def arcsinh_out(input_tensor: torch.Tensor, out: torch.Tensor):
    """Elementwise inverse hyperbolic sine, written into the preallocated *out*."""
    return _arcsinh_impl(input_tensor, out_tensor=out)