Coverage for src/flag_gems/ops/selu_.py: 51%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def selu_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place SELU: x = scale * (x if x > 0 else alpha * (exp(x) - 1))."""
    # Fixed SELU constants (Klambauer et al., "Self-Normalizing Neural Networks").
    ALPHA = 1.6732632423543772
    SCALE = 1.0507009873554805

    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds)
    # Compute in float32 for accuracy on fp16/bf16 inputs, then cast back.
    v32 = val.to(tl.float32)
    neg_branch = ALPHA * (tl.exp(v32) - 1.0)
    out32 = SCALE * tl.where(v32 > 0, v32, neg_branch)
    tl.store(x_ptr + idx, out32.to(val.dtype), mask=in_bounds)
def selu_(*args, **kwargs):
    """Apply SELU in place and return the mutated tensor.

    The target tensor is taken from the first positional argument or from
    the 'input'/'self'/'x' keyword. Non-contiguous tensors and dtypes the
    Triton kernel does not support are delegated to ``torch.ops.aten.selu_``.

    Raises:
        ValueError: if no tensor argument can be found.
    """
    logger.debug("GEMS SELU_")

    # Locate the tensor argument: first positional, then known keywords.
    x = args[0] if args and torch.is_tensor(args[0]) else None
    if x is None:
        for key in ("input", "self", "x"):
            candidate = kwargs.get(key)
            if torch.is_tensor(candidate):
                x = candidate
                break
    if x is None:
        raise ValueError(
            "selu_ expects a Tensor as the first argument or under 'input'/'self'/'x' keyword."
        )

    # The Triton path only handles contiguous fp16/bf16/fp32 tensors;
    # everything else goes through the reference aten implementation.
    if not x.is_contiguous() or x.dtype not in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ):
        torch.ops.aten.selu_(x)
        return x

    total = x.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    # Bind the launch to the tensor's device so the kernel runs where x lives.
    with torch_device_fn.device(x.device):
        selu_kernel_[grid](x, total, BLOCK_SIZE=BLOCK_SIZE)
    return x