Coverage for src/flag_gems/experimental_ops/celu.py: 0%
67 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def celu_kernel(x_ptr, out_ptr, n_elements, alpha, BLOCK_SIZE: tl.constexpr):
    """Elementwise CELU: y = x if x > 0 else alpha * (exp(x / alpha) - 1).

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    input. Math is done in float32 for accuracy, then cast back to the
    input dtype before storing.
    """
    block_id = tl.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    vals = tl.load(x_ptr + offsets, mask=in_bounds)

    # Compute in float32 regardless of storage dtype (e.g. fp16/bf16).
    vals_f32 = vals.to(tl.float32)
    neg_branch = alpha * (tl.exp(vals_f32 / alpha) - 1.0)
    result = tl.where(vals_f32 > 0.0, vals_f32, neg_branch).to(vals.dtype)

    tl.store(out_ptr + offsets, result, mask=in_bounds)
22def _parse_alpha(alpha):
23 if isinstance(alpha, torch.Tensor):
24 if alpha.numel() != 1:
25 raise ValueError("alpha tensor must be a scalar (numel() == 1)")
26 alpha = float(alpha.item())
27 else:
28 alpha = float(alpha)
29 if alpha == 0.0:
30 raise ValueError("alpha must be non-zero")
31 return alpha
def celu(input: torch.Tensor, alpha: float = 1.0):
    """Apply CELU elementwise and return a new tensor.

    Args:
        input: floating-point CUDA tensor of any shape.
        alpha: non-zero scale for the negative branch (number or scalar
            tensor).

    Returns:
        A contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        TypeError: if ``input`` is not a floating-point ``torch.Tensor``.
        ValueError: if ``input`` is not on a CUDA device or ``alpha`` is
            invalid.
    """
    alpha = _parse_alpha(alpha)
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_cuda:
        raise ValueError("input must be on CUDA device")
    if not torch.is_floating_point(input):
        raise TypeError("input must be a floating point tensor")

    # The kernel indexes a flat buffer, so work on a contiguous view.
    src = input.contiguous()
    result = torch.empty_like(src)

    total = result.numel()
    if total == 0:
        return result

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    celu_kernel[grid](src, result, total, alpha, BLOCK_SIZE=1024)
    return result
def celu_out(input: torch.Tensor, alpha: float = 1.0, out: torch.Tensor = None):
    """Apply CELU elementwise, writing the result into ``out``.

    Args:
        input: floating-point CUDA tensor.
        alpha: non-zero scale for the negative branch (number or scalar
            tensor).
        out: preallocated tensor with the same shape and dtype as
            ``input``; written in place and returned.

    Returns:
        ``out``.

    Raises:
        TypeError: if ``input``/``out`` are missing or not floating-point
            tensors.
        ValueError: on device, shape, or dtype mismatches, or invalid
            ``alpha``.
    """
    alpha = _parse_alpha(alpha)
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if out is None or not isinstance(out, torch.Tensor):
        raise TypeError("out must be a preallocated torch.Tensor")
    if not input.is_cuda or not out.is_cuda:
        raise ValueError("input and out must be on CUDA device")
    if not torch.is_floating_point(input) or not torch.is_floating_point(out):
        raise TypeError("input and out must be floating point tensors")
    if out.shape != input.shape:
        raise ValueError("out must have the same shape as input")
    if out.dtype != input.dtype:
        raise ValueError("out must have the same dtype as input")

    src = input.contiguous()
    # The kernel needs a contiguous destination; stage through a scratch
    # buffer when the caller's ``out`` is not contiguous.
    dest = out if out.is_contiguous() else torch.empty_like(src)

    total = src.numel()
    if total != 0:
        def grid(meta):
            return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

        celu_kernel[grid](src, dest, total, alpha, BLOCK_SIZE=1024)

    if dest is not out:
        out.copy_(dest)
    return out