Coverage for src/flag_gems/experimental_ops/silu_.py: 0% (36 statements)
coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def silu_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place SiLU kernel: x <- x * sigmoid(x).

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    buffer. The activation is evaluated in float32 and cast back to the
    input dtype before the masked store.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=in_bounds)
    # Promote to fp32 so fp16/bf16 inputs don't lose precision in sigmoid.
    vals_f32 = vals.to(tl.float32)
    result = (vals_f32 * tl.sigmoid(vals_f32)).to(vals.dtype)
    tl.store(x_ptr + offs, result, mask=in_bounds)
19_silu_kernel = silu_
def silu_(*args, **kwargs):
    """In-place SiLU activation: x <- x * sigmoid(x).

    The tensor is taken from the first positional argument, or from the
    ``input``/``self`` keyword. Contiguous CUDA tensors of fp16/bf16/fp32
    are processed by the Triton kernel; every other case is delegated to
    ``torch.ops.aten.silu_``. Returns the (mutated) input tensor.

    Raises:
        ValueError: if no tensor argument was supplied.
        TypeError: if a CUDA tensor has a non-floating-point dtype.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))
    if x is None:
        raise ValueError("silu_ expects a tensor as the first argument (self).")

    # Non-CUDA tensors: let PyTorch handle them.
    if not x.is_cuda:
        return torch.ops.aten.silu_(x)

    if not x.dtype.is_floating_point:
        raise TypeError(f"silu_ expects a floating point tensor, got {x.dtype}")

    # The Triton path only supports contiguous fp16/bf16/fp32 buffers.
    triton_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if x.dtype not in triton_dtypes or not x.is_contiguous():
        return torch.ops.aten.silu_(x)

    numel = x.numel()
    # Nothing to do for an empty tensor; skip the kernel launch.
    if numel == 0:
        return x

    block = 1024
    launch_grid = (triton.cdiv(numel, block),)
    _silu_kernel[launch_grid](x, numel, BLOCK_SIZE=block)
    return x