Coverage for src/flag_gems/experimental_ops/cosh_.py: 0%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def cosh_(
    x_ptr,  # *Pointer* to input vector (modified in-place).
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
):
    # In-place elementwise cosh over a contiguous vector: each program
    # instance handles one BLOCK_SIZE-wide slice of the input.
    prog = tl.program_id(axis=0)
    # Element offsets covered by this program, with a mask for the tail
    # block that may run past n_elements.
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # Promote to float32 so fp16/bf16 inputs keep precision through exp();
    # tl.store casts back to the pointer's dtype on write.
    vals_f32 = tl.cast(vals, tl.float32)
    # cosh(x) = (e^x + e^-x) / 2
    result = (tl.exp(vals_f32) + tl.exp(-vals_f32)) * 0.5
    tl.store(x_ptr + idx, result, mask=in_bounds)

23 

24 

# Keep a handle to the @triton.jit kernel before the Python wrapper below
# rebinds the module-level name `cosh_`.
_triton_cosh_kernel = cosh_

26 

27 

def cosh_(*args, **kwargs):
    """In-place hyperbolic cosine: ``x <- cosh(x)`` via a Triton kernel.

    The input tensor may be given as the first positional argument or under
    the keyword names ``input``, ``x``, or ``self`` (checked in that order).

    Args:
        *args: ``args[0]``, if present, is the tensor to transform.
        **kwargs: alternative spellings of the input tensor.

    Returns:
        torch.Tensor: the same tensor object, modified in place.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on CUDA,
            or it is not contiguous.
        TypeError: if the input is not a ``torch.Tensor``, or its dtype is
            not float16, bfloat16, or float32.
    """
    x = args[0] if args else None
    if x is None:
        # Explicit `is not None` checks instead of `or`-chaining:
        # `tensor or ...` invokes Tensor.__bool__, which raises a
        # RuntimeError for tensors with more than one element and
        # silently skips falsy 0-d tensors.
        for key in ("input", "x", "self"):
            candidate = kwargs.get(key)
            if candidate is not None:
                x = candidate
                break
    if x is None:
        raise ValueError("cosh_ expects a tensor as the first positional argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("cosh_ expects a torch.Tensor input.")
    if not x.is_cuda:
        raise ValueError("cosh_ Triton kernel requires a CUDA tensor.")
    if not x.is_contiguous():
        raise ValueError(
            "cosh_ Triton kernel currently supports contiguous tensors only."
        )
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "cosh_ Triton kernel supports float16, bfloat16, and float32 tensors."
        )

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return x
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _triton_cosh_kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x