Coverage for src/flag_gems/experimental_ops/selu.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# Elementwise SELU kernel:
#     y = scale * (x            if x > 0
#                  alpha * (exp(x) - 1) otherwise)
# Each program handles one contiguous chunk of BLOCK_SIZE elements of the flat
# buffers; lanes past n_elements in the final chunk are masked off.
#
# Args:
#     x_ptr: input buffer, n_elements values are read.
#     y_ptr: output buffer, written in the input's dtype.
#     n_elements: total number of elements to process.
#     BLOCK_SIZE: elements per program (compile-time constant).
#
# NOTE: documentation is kept in comments rather than a docstring so the JIT
# frontend only sees executable statements.
@triton.jit
def selu_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Widen the program id before multiplying so offsets cannot overflow int32
    # on tensors with >= 2**31 elements (an overflowed, negative offset would
    # also pass the `offsets < n_elements` mask).
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 (helps fp16/bf16 accuracy).
    # NOTE(review): float64 inputs are also computed in float32 here, so they
    # lose precision relative to torch.selu — consider a dedicated fp64 path.
    x_f32 = x.to(tl.float32)

    # SELU constants from PyTorch
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    is_pos = x_f32 > 0.0
    # Feed 0 to exp() on positive lanes so it cannot overflow there.
    # tl.where is used instead of tl.minimum because tl.minimum's NaN handling
    # is not guaranteed (by default it may return the non-NaN operand, which
    # turned selu(NaN) into 0). Here NaN > 0 is False, so a NaN input flows
    # into exp() and stays NaN in the output.
    x_neg = tl.where(is_pos, 0.0, x_f32)
    neg_part = alpha * (tl.exp(x_neg) - 1.0)
    out_f32 = scale * tl.where(is_pos, x_f32, neg_part)

    tl.store(y_ptr + offsets, out_f32.to(x.dtype), mask=mask)

28 

29 

def selu(*args, **kwargs):
    """Apply SELU elementwise, using the Triton kernel for CUDA tensors.

    The input is taken from the first positional argument, otherwise from the
    ``input`` or ``self`` keyword argument; all other arguments are ignored.
    Non-CUDA tensors are delegated to ``torch.ops.aten.selu``.

    Returns:
        A new tensor with the same shape and dtype as the input. On the CUDA
        path the result is contiguous.

    Raises:
        TypeError: if no input is supplied, the input is not a
            ``torch.Tensor``, or a CUDA input is not floating point.
    """
    # Resolve input tensor from args/kwargs
    if args:
        x = args[0]
    elif "input" in kwargs:
        x = kwargs["input"]
    elif "self" in kwargs:
        x = kwargs["self"]
    else:
        raise TypeError("selu() missing required argument 'input' (pos 1)")

    if not isinstance(x, torch.Tensor):
        raise TypeError("selu() expected a torch.Tensor as input")

    # Fallback to PyTorch if not on CUDA
    if x.device.type != "cuda":
        return torch.ops.aten.selu(x)

    if not x.is_floating_point():
        raise TypeError("selu() expected a floating point tensor")

    # The kernel indexes a flat buffer, so it needs contiguous memory.
    x_contig = x.contiguous()
    y = torch.empty_like(x_contig)

    n_elements = y.numel()
    if n_elements == 0:
        # Nothing to compute; skip compiling/launching the kernel.
        return y

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Triton launches on the *current* CUDA device rather than the device the
    # tensors live on, so switch to x's device for the launch; otherwise an
    # input on e.g. cuda:1 would be accessed from a kernel running on cuda:0.
    with torch.cuda.device(x.device):
        selu_kernel[grid](x_contig, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return y