Coverage for src/flag_gems/experimental_ops/celu.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def celu_kernel(x_ptr, out_ptr, n_elements, alpha, BLOCK_SIZE: tl.constexpr):
    """Elementwise CELU: y = x if x > 0 else alpha * (exp(x / alpha) - 1).

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    input; out-of-range lanes are masked off.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)

    # Do the math in float32 for accuracy, then cast back to the input dtype.
    v32 = vals.to(tl.float32)
    y32 = tl.where(v32 > 0.0, v32, alpha * (tl.exp(v32 / alpha) - 1.0))

    tl.store(out_ptr + offs, y32.to(vals.dtype), mask=in_bounds)

20 

21 

22def _parse_alpha(alpha): 

23 if isinstance(alpha, torch.Tensor): 

24 if alpha.numel() != 1: 

25 raise ValueError("alpha tensor must be a scalar (numel() == 1)") 

26 alpha = float(alpha.item()) 

27 else: 

28 alpha = float(alpha) 

29 if alpha == 0.0: 

30 raise ValueError("alpha must be non-zero") 

31 return alpha 

32 

33 

def celu(input: torch.Tensor, alpha: float = 1.0):
    """Apply CELU elementwise and return a new tensor.

    Args:
        input: floating-point CUDA tensor of any shape.
        alpha: non-zero scalar (number or scalar tensor).

    Returns:
        A contiguous tensor with the same shape and dtype as *input*.

    Raises:
        TypeError: if *input* is not a floating-point torch.Tensor.
        ValueError: if *input* is not on a CUDA device or alpha is zero.
    """
    alpha = _parse_alpha(alpha)
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_cuda:
        raise ValueError("input must be on CUDA device")
    if not torch.is_floating_point(input):
        raise TypeError("input must be a floating point tensor")

    src = input.contiguous()
    result = torch.empty_like(src)

    total = result.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return result

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    celu_kernel[grid](src, result, total, alpha, BLOCK_SIZE=1024)
    return result

53 

54 

def celu_out(input: torch.Tensor, alpha: float = 1.0, out: torch.Tensor = None):
    """Apply CELU elementwise into a preallocated *out* tensor.

    Args:
        input: floating-point CUDA tensor.
        alpha: non-zero scalar (number or scalar tensor).
        out: preallocated floating-point CUDA tensor matching *input*'s
            shape and dtype.

    Returns:
        *out*, filled with the CELU of *input*.

    Raises:
        TypeError: for a non-tensor input/out or non-floating dtypes.
        ValueError: for non-CUDA tensors, shape/dtype mismatch, or zero alpha.
    """
    alpha = _parse_alpha(alpha)
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if out is None or not isinstance(out, torch.Tensor):
        raise TypeError("out must be a preallocated torch.Tensor")
    if not input.is_cuda or not out.is_cuda:
        raise ValueError("input and out must be on CUDA device")
    if not torch.is_floating_point(input) or not torch.is_floating_point(out):
        raise TypeError("input and out must be floating point tensors")
    if out.shape != input.shape:
        raise ValueError("out must have the same shape as input")
    if out.dtype != input.dtype:
        raise ValueError("out must have the same dtype as input")

    src = input.contiguous()
    # The kernel assumes flat contiguous storage; stage through a scratch
    # buffer when *out* is not contiguous and copy back afterwards.
    dst = out if out.is_contiguous() else torch.empty_like(src)

    total = src.numel()
    if total == 0:
        if dst is not out:
            out.copy_(dst)
        return out

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    celu_kernel[grid](src, dst, total, alpha, BLOCK_SIZE=1024)

    if dst is not out:
        out.copy_(dst)
    return out