Coverage for src/flag_gems/ops/selu_.py: 51%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

# Module-level logger; selu_ emits a "GEMS SELU_" debug record through it on each call.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def selu_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Apply SELU in place to the first ``n_elements`` values at ``x_ptr``.

    Element ``i`` is read from and written back to ``x_ptr + i``, so the
    caller must pass a flat, contiguous buffer. Program ``p`` handles the
    tile ``[p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)``; lanes past
    ``n_elements`` in the last tile are masked off.

        SELU(x) = scale * (x            if x > 0
                           alpha * (exp(x) - 1) otherwise)

    The arithmetic is done in float32 and the result is cast back to the
    element type loaded from ``x_ptr`` before the store.
    """
    # Promote to int64 before scaling: with an int32 program id,
    # pid * BLOCK_SIZE wraps once the tensor has ~2**31 elements, and the
    # wrapped (negative) offsets would still pass the `< n_elements` mask,
    # producing out-of-bounds loads/stores.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in float32 regardless of the storage dtype, then cast back.
    x_f32 = x.to(tl.float32)
    # SELU constants (same values as documented for torch.nn.SELU).
    alpha = 1.6732632423543772
    scale = 1.0507009873554805
    y_f32 = scale * tl.where(x_f32 > 0, x_f32, alpha * (tl.exp(x_f32) - 1.0))
    y = y_f32.to(x.dtype)

    tl.store(x_ptr + offsets, y, mask=mask)

29 

30 

# SELU constants: the same values used inside selu_kernel_ and documented
# for torch.nn.SELU.
_SELU_ALPHA = 1.6732632423543772
_SELU_SCALE = 1.0507009873554805

# Dtypes handled by the Triton kernel; anything else is delegated to ATen.
_SELU_KERNEL_DTYPES = frozenset({torch.float16, torch.bfloat16, torch.float32})

# Keyword names accepted for the input tensor, checked in this order.
_SELU_INPUT_KEYS = ("input", "self", "x")


def _selu_input(args, kwargs):
    """Return the tensor selu_ should modify, or raise ValueError if none is given."""
    if args and torch.is_tensor(args[0]):
        return args[0]
    for key in _SELU_INPUT_KEYS:
        # kwargs.get returns None for a missing key, which is not a tensor.
        candidate = kwargs.get(key)
        if torch.is_tensor(candidate):
            return candidate
    raise ValueError(
        "selu_ expects a Tensor as the first argument or under 'input'/'self'/'x' keyword."
    )


def selu_(*args, **kwargs):
    """In-place SELU: overwrite the input tensor with SELU(input) and return it.

    The tensor is taken from the first positional argument when that is a
    tensor, otherwise from the ``input``, ``self`` or ``x`` keyword (checked
    in that order).

    Contiguous float16 / bfloat16 / float32 tensors are processed by
    ``selu_kernel_``. Every other tensor (non-contiguous, or any other
    dtype) is handled by ATen's ``elu_`` with the SELU constants.

    Returns:
        The same tensor object that was passed in, now holding SELU(input).

    Raises:
        ValueError: if no tensor is found in the accepted argument slots.
    """
    logger.debug("GEMS SELU_")
    x = _selu_input(args, kwargs)

    if (not x.is_contiguous()) or (x.dtype not in _SELU_KERNEL_DTYPES):
        # ATen implements selu_ as elu_(self, SELU_ALPHA, SELU_SCALE), so this
        # matches aten::selu_ numerically. It deliberately avoids calling
        # torch.ops.aten.selu_. If this function is registered as the
        # aten::selu_ override for x's device (registration is not visible
        # from this file), that call would dispatch straight back into
        # selu_ with the same tensor and recurse until RecursionError.
        torch.ops.aten.elu_(x, alpha=_SELU_ALPHA, scale=_SELU_SCALE)
        return x

    n_elements = x.numel()
    if n_elements == 0:
        return x

    BLOCK_SIZE = 1024
    # One program per BLOCK_SIZE-wide tile; the kernel masks the ragged tail.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    with torch_device_fn.device(x.device):
        selu_kernel_[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x