# Coverage report header (coverage.py v7.6.9, created 2026-03-09 01:57 +0800):
# src/flag_gems/experimental_ops/celu_.py — 0% of 24 statements covered.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def celu_(
    x_ptr,  # pointer to tensor data; overwritten in place
    n_elements,  # total number of elements addressed from x_ptr
    alpha,  # CELU alpha scalar (assumed non-zero — TODO confirm caller enforces)
    BLOCK_SIZE: tl.constexpr,
):
    """In-place elementwise CELU: x = x if x > 0 else alpha * (exp(x / alpha) - 1)."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    # Promote to float32 so exp/divide are done at full precision
    # regardless of the tensor's storage dtype.
    v = tl.cast(val, tl.float32)
    a = tl.cast(alpha, tl.float32)

    neg_branch = a * (tl.exp(v / a) - 1.0)
    result = tl.where(v > 0, v, neg_branch)

    # Cast back to the original dtype before writing in place.
    tl.store(x_ptr + idx, tl.cast(result, val.dtype), mask=in_bounds)
# Preserve a reference to the Triton kernel before the Python wrapper below
# reuses the name `celu_`; the wrapper launches the kernel via `celu_kernel`.
celu_kernel = celu_
def celu_(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Apply the CELU activation to ``x`` in place and return ``x``.

    CELU(x) = x if x > 0 else alpha * (exp(x / alpha) - 1)

    Args:
        x: Floating-point CUDA tensor, modified in place. Must be
            contiguous: the kernel indexes the underlying storage as a
            flat array, so non-contiguous layouts would be corrupted.
        alpha: CELU alpha parameter; must be non-zero (the kernel
            divides by it).

    Returns:
        The same tensor object ``x``, for call chaining.

    Raises:
        ValueError: if ``x`` is not a contiguous floating-point CUDA
            tensor, or if ``alpha`` is zero.
    """
    # Validate with real exceptions, not `assert` — asserts are stripped
    # under `python -O` and would silently disable these checks.
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device.")
    if not x.is_floating_point():
        raise ValueError("CELU requires a floating point tensor.")
    if not x.is_contiguous():
        raise ValueError("In-place CELU requires a contiguous tensor.")
    if alpha == 0:
        raise ValueError("alpha must be non-zero.")

    n_elements = x.numel()
    if n_elements == 0:
        return x  # nothing to do; skip the zero-size grid launch

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    celu_kernel[grid](x, n_elements, alpha, BLOCK_SIZE=1024)
    return x