Coverage for src/flag_gems/experimental_ops/arcsinh.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def arcsinh_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise asinh over a flat buffer: out[i] = asinh(x[i]).

    Computes in float32 via the identity
        asinh(x) = sign(x) * log(|x| + sqrt(x*x + 1)).
    Working on |x| and restoring the sign afterwards avoids the
    catastrophic cancellation the naive log(x + sqrt(x*x + 1)) form
    suffers for large negative x, where x + sqrt(x*x + 1) -> 0.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-off lanes load 0; their results are never stored.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # Compute in float32 regardless of input dtype (e.g. float16 inputs).
    x_f32 = x.to(tl.float32)
    abs_x = tl.abs(x_f32)
    y_abs = tl.log(abs_x + tl.sqrt(abs_x * abs_x + 1.0))
    # asinh is odd: asinh(-x) == -asinh(x).
    y_f32 = tl.where(x_f32 < 0, -y_abs, y_abs)

    # Store result; tl.store casts to out_ptr's element type as needed.
    tl.store(out_ptr + offsets, y_f32, mask=mask)

23 

24 

25def _ensure_cuda_tensor(t): 

26 if not isinstance(t, torch.Tensor): 

27 raise TypeError("Expected a torch.Tensor") 

28 if not t.is_cuda: 

29 raise ValueError("Input tensors must be on CUDA device") 

30 if t.is_complex(): 

31 raise NotImplementedError( 

32 "Complex dtypes are not supported by this Triton kernel" 

33 ) 

34 

35 

def _arcsinh_impl(input_tensor: torch.Tensor, out_tensor: torch.Tensor = None):
    """Elementwise asinh on a CUDA tensor via the Triton kernel.

    Args:
        input_tensor: non-complex CUDA tensor. Floating inputs keep
            their dtype; all other dtypes are promoted to float32.
        out_tensor: optional pre-allocated result tensor. Must be a
            CUDA tensor with the same number of elements as the input
            and the promoted dtype.

    Returns:
        ``out_tensor`` when provided, otherwise a freshly allocated
        tensor holding the result.

    Raises:
        TypeError / ValueError / NotImplementedError: from input
            validation, or when ``out_tensor`` has the wrong dtype or
            element count.
    """
    _ensure_cuda_tensor(input_tensor)

    # Basic promotion: float -> same dtype, anything else -> float32.
    result_dtype = (
        input_tensor.dtype if input_tensor.is_floating_point() else torch.float32
    )

    n_elements = input_tensor.numel()

    if out_tensor is not None:
        _ensure_cuda_tensor(out_tensor)
        if out_tensor.numel() != n_elements:
            raise ValueError(
                "Output tensor must have the same number of elements as input"
            )
        # Enforce dtype consistent with promotion.
        if out_tensor.dtype != result_dtype:
            raise TypeError(
                f"Output tensor has dtype {out_tensor.dtype}, expected {result_dtype}"
            )
        out = out_tensor
    else:
        out = torch.empty_like(
            input_tensor, dtype=result_dtype, device=input_tensor.device
        )

    # The kernel addresses flat, contiguous buffers.
    x_contig = input_tensor.contiguous()
    out_contig = out if out.is_contiguous() else out.contiguous()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    arcsinh_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    # If out was non-contiguous, .contiguous() above produced a copy the
    # kernel wrote into; propagate those results back into `out`.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)

    return out

77 

78 

def arcsinh(input_tensor: torch.Tensor):
    """Return the elementwise inverse hyperbolic sine of a CUDA tensor."""
    result = _arcsinh_impl(input_tensor)
    return result

81 

82 

def arcsinh_out(input_tensor: torch.Tensor, out: torch.Tensor):
    """Compute elementwise asinh of ``input_tensor`` into ``out``.

    Returns ``out`` after filling it with the result.
    """
    result = _arcsinh_impl(input_tensor, out)
    return result