Coverage for src/flag_gems/ops/asinh_.py: 44%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def asinh_kernel_(
    x_ptr, n_elements, BLOCK_SIZE: tl.constexpr, COMPUTE_FP32: tl.constexpr
):
    """Overwrite a flat buffer with its element-wise inverse hyperbolic sine.

    Each program handles one ``BLOCK_SIZE``-wide slice of the ``n_elements``
    values starting at ``x_ptr``; the result is stored back to the same
    addresses. When ``COMPUTE_FP32`` is set, values are widened to float32
    for the arithmetic and narrowed back to the loaded dtype before the store.
    """
    # Widen the program id so the block offset cannot wrap around int32 on
    # buffers holding more than 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if COMPUTE_FP32:
        v = x.to(tl.float32)
    else:
        v = x

    # asinh is odd, so evaluate on |x| and restore the sign afterwards.
    # Feeding a negative x straight into log(x + sqrt(x*x + 1)) cancels
    # catastrophically (e.g. float32 x = -1e4 yields log(0) = -inf, and
    # x = -inf yields NaN).
    a = tl.abs(v)

    # Direct formula; accurate as long as a*a does not overflow.
    near = tl.log(a + tl.sqrt(a * a + 1.0))

    # Overflow-free rewrite for large magnitudes:
    #   log(a + sqrt(a*a + 1)) == log(a) + log(1 + sqrt(1 + 1/a**2))
    # Lanes with small a may produce inf/NaN here; tl.where discards them.
    inv = 1.0 / a
    far = tl.log(a) + tl.log(1.0 + tl.sqrt(1.0 + inv * inv))

    # NaN compares false against the threshold and propagates through `near`.
    mag = tl.where(a > 1.0e10, far, near)
    y = tl.where(v < 0.0, -mag, mag)

    tl.store(x_ptr + offsets, y.to(x.dtype), mask=mask)

28 

29 

def asinh_(*args, **kwargs):
    """Apply inverse hyperbolic sine to a tensor in place.

    The tensor is the first positional argument when that is a
    ``torch.Tensor``; otherwise it is the first tensor found under the
    keyword names ``input``, ``self`` or ``x`` (checked in that order).

    Returns:
        The tensor that was passed in, after being overwritten. For dtypes
        other than float16/bfloat16/float32/float64 the call is delegated to
        ``torch.ops.aten.asinh_`` and its return value is passed through.

    Raises:
        ValueError: if no tensor argument can be located.
    """
    x = None
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    else:
        for key in ("input", "self", "x"):
            val = kwargs.get(key)
            if isinstance(val, torch.Tensor):
                x = val
                break
    if x is None:
        raise ValueError("asinh_: expected a Tensor as the first argument")

    # The Triton kernel only covers these floating dtypes.
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        return torch.ops.aten.asinh_(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip launching a zero-sized grid.
        return x

    BLOCK_SIZE = 1024
    # Half-precision inputs are computed in float32 inside the kernel.
    COMPUTE_FP32 = x.dtype in (torch.float16, torch.bfloat16)

    # The kernel addresses the buffer linearly, so a non-contiguous tensor is
    # processed through a contiguous copy that is written back afterwards.
    buf = x if x.is_contiguous() else x.contiguous()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(buf.device):
        asinh_kernel_[grid](
            buf, n_elements, BLOCK_SIZE=BLOCK_SIZE, COMPUTE_FP32=COMPUTE_FP32
        )

    if buf is not x:
        x.copy_(buf)
    return x

57 else: 

58 y = x.contiguous() 

59 n_elements = y.numel() 

60 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

61 with torch_device_fn.device(y.device): 

62 asinh_kernel_[grid]( 

63 y, n_elements, BLOCK_SIZE=BLOCK_SIZE, COMPUTE_FP32=COMPUTE_FP32 

64 ) 

65 x.copy_(y) 

66 return x