Coverage for src/flag_gems/ops/selu.py: 51%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

# Module-level logger; selu() emits a debug message through it on every call.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def selu_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise SELU: y = scale * (x if x > 0 else alpha * (exp(x) - 1)).

    Each program handles one contiguous tile of BLOCK_SIZE elements of the
    flat buffers x_ptr / y_ptr; lanes at or past n_elements are masked off.

    Args:
        x_ptr: pointer to the input elements (read).
        y_ptr: pointer to the output elements (written); results are cast
            back to the input element type before the store.
        n_elements: number of valid elements.
        BLOCK_SIZE: compile-time tile width handled by one program.
    """
    pid = tl.program_id(axis=0)
    # Widen to int64 before multiplying so offsets cannot wrap around for
    # tensors with 2**31 or more elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Half/bfloat16/fp32 inputs are computed in fp32; fp64 inputs stay in fp64
    # so double-precision results are not truncated to single precision.
    # x.dtype is a compile-time value, so this branch is resolved statically.
    if x.dtype == tl.float64:
        xc = x
    else:
        xc = x.to(tl.float32)

    # SELU constants from PyTorch
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    # Clamp to non-positive before exp so the (discarded) positive lanes cannot
    # overflow. tl.where is used instead of tl.minimum: tl.minimum's default
    # NaN handling may return the non-NaN operand (turning NaN into 0), whereas
    # here NaN > 0 is False, so NaN flows through and the result stays NaN,
    # matching torch.selu.
    x_neg = tl.where(xc > 0.0, 0.0, xc)
    neg_part = alpha * (tl.exp(x_neg) - 1.0)
    out = scale * tl.where(xc > 0.0, xc, neg_part)

    tl.store(y_ptr + offsets, out.to(x.dtype), mask=mask)

35 

36 

def selu(*args, **kwargs):
    """FlagGems implementation of ``selu`` backed by ``selu_kernel``.

    The input tensor is taken from the first positional argument, otherwise
    from the ``input`` keyword, otherwise from the ``self`` keyword. Any other
    arguments are ignored.

    Returns:
        A new contiguous tensor with the same shape, dtype and device as the
        input, holding selu(input).

    Raises:
        TypeError: if no input is supplied, the input is not a
            ``torch.Tensor``, or it is not a floating-point tensor.
    """
    logger.debug("GEMS SELU")
    if args:
        x = args[0]
    elif "input" in kwargs:
        x = kwargs["input"]
    elif "self" in kwargs:
        x = kwargs["self"]
    else:
        raise TypeError("selu() missing required argument 'input' (pos 1)")

    if not isinstance(x, torch.Tensor):
        raise TypeError("selu() expected a torch.Tensor as input")
    if not x.is_floating_point():
        raise TypeError("selu() expected a floating point tensor")

    # The kernel walks a flat buffer with linear offsets, so give it
    # contiguous storage; the output is allocated to match.
    x_contig = x.contiguous()
    y = torch.empty_like(x_contig)

    n_elements = y.numel()
    if n_elements == 0:
        # Nothing to compute: skip the device guard and the zero-sized launch.
        return y

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    with torch_device_fn.device(x_contig.device):
        selu_kernel[grid](x_contig, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return y