Coverage for src/flag_gems/experimental_ops/celu_.py: 0%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def celu_(
    x_ptr,  # Pointer to the tensor data; output is written back in place
    n_elements,  # Total number of elements to process
    alpha,  # CELU scale parameter (scalar)
    BLOCK_SIZE: tl.constexpr,
):
    # In-place CELU: x -> x if x > 0 else alpha * (exp(x / alpha) - 1).
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    # Do the math in float32 regardless of the storage dtype, then cast back,
    # so exp() keeps reasonable precision for half-precision inputs.
    v32 = tl.cast(vals, tl.float32)
    a32 = tl.cast(alpha, tl.float32)

    result = tl.where(v32 > 0, v32, a32 * (tl.exp(v32 / a32) - 1.0))
    tl.store(x_ptr + idx, tl.cast(result, vals.dtype), mask=in_bounds)

27 

28 

# Keep a handle to the @triton.jit kernel before the Python wrapper below
# rebinds the module-level name `celu_`; the wrapper launches via this alias.
celu_kernel = celu_

31 

32 

def celu_(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Apply CELU to ``x`` in place and return ``x``.

    CELU(x) = x if x > 0 else alpha * (exp(x / alpha) - 1).

    Args:
        x: CUDA floating-point tensor, modified in place. Must be contiguous:
           the kernel addresses elements by flat offset over ``numel()``.
        alpha: CELU scale parameter; must be nonzero (the kernel divides by it).

    Returns:
        The same tensor ``x``, for call chaining.

    Raises:
        ValueError: if ``x`` is not on CUDA, not contiguous, or ``alpha`` is 0.
        TypeError: if ``x`` is not a floating-point tensor.
    """
    # Raise real exceptions instead of assert: asserts vanish under `python -O`.
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device.")
    if not x.is_floating_point():
        raise TypeError("CELU requires a floating point tensor.")
    if not x.is_contiguous():
        # The kernel uses flat offsets, which only match memory layout for
        # contiguous tensors; running anyway would corrupt strided views.
        raise ValueError("In-place CELU requires a contiguous tensor.")
    if alpha == 0:
        # exp(x / alpha) would divide by zero inside the kernel.
        raise ValueError("alpha must be nonzero.")

    n_elements = x.numel()
    if n_elements == 0:
        return x  # nothing to do; avoid launching an empty grid

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    celu_kernel[grid](x, n_elements, alpha, BLOCK_SIZE=1024)
    return x