Coverage for src/flag_gems/experimental_ops/cosh_.py: 0%
32 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def cosh_(
    x_ptr,  # *Pointer* to input vector (modified in-place).
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
):
    # In-place element-wise hyperbolic cosine: x_ptr[i] = cosh(x_ptr[i]).
    # Each program instance handles one contiguous BLOCK_SIZE-wide slice.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Upcast to fp32 so fp16/bf16 inputs keep precision through the exps.
    val = tl.cast(tl.load(x_ptr + idx, mask=in_bounds), tl.float32)
    # cosh(x) = (e^x + e^-x) / 2; the store casts back to the pointer dtype.
    result = 0.5 * (tl.exp(val) + tl.exp(-val))
    tl.store(x_ptr + idx, result, mask=in_bounds)
25_triton_cosh_kernel = cosh_
def cosh_(*args, **kwargs):
    """Compute cosh(x) element-wise, in place, via the Triton kernel.

    The tensor may be given as the first positional argument or through
    the ``input``, ``x``, or ``self`` keyword (mirroring torch-style calls).

    Returns:
        The same tensor, modified in place.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on CUDA,
            or it is not contiguous.
        TypeError: if the input is not a torch.Tensor or has an
            unsupported dtype.
    """
    if args:
        x = args[0]
    else:
        # Probe the keywords one at a time with explicit `is not None`
        # checks. Chaining `a or b or c` would evaluate tensor truthiness,
        # which raises RuntimeError for any multi-element tensor.
        x = None
        for key in ("input", "x", "self"):
            candidate = kwargs.get(key)
            if candidate is not None:
                x = candidate
                break
    if x is None:
        raise ValueError("cosh_ expects a tensor as the first positional argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("cosh_ expects a torch.Tensor input.")
    if not x.is_cuda:
        raise ValueError("cosh_ Triton kernel requires a CUDA tensor.")
    if not x.is_contiguous():
        raise ValueError(
            "cosh_ Triton kernel currently supports contiguous tensors only."
        )
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "cosh_ Triton kernel supports float16, bfloat16, and float32 tensors."
        )
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return x
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _triton_cosh_kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x