Coverage for src/flag_gems/experimental_ops/silu.py: 0%
64 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def silu_kernel(
    x_ptr,  # *Pointer* to input tensor
    y_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,  # elements handled by each program instance
    COMPUTE_IN_FP32: tl.constexpr,  # compile-time flag: do the math in float32
):
    """Elementwise SiLU kernel: y = x / (1 + exp(-x)).

    Each program instance covers one BLOCK_SIZE-wide contiguous slice of the
    flattened input; ``mask`` guards the final partial block past
    ``n_elements``. When COMPUTE_IN_FP32 is set, the arithmetic is upcast to
    float32 and the result cast back to the input dtype (the host wrapper
    enables this for float16/bfloat16 inputs).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # COMPUTE_IN_FP32 is a tl.constexpr, so this branch is resolved at
    # compile time and only one path is emitted per specialization.
    if COMPUTE_IN_FP32:
        xf = x.to(tl.float32)
        yf = xf / (1.0 + tl.exp(-xf))
        y = yf.to(x.dtype)
    else:
        y = x / (1.0 + tl.exp(-x))
    tl.store(y_ptr + offsets, y, mask=mask)
29def _silu_impl(x: torch.Tensor, out: torch.Tensor = None):
30 if not x.is_cuda:
31 raise ValueError("Input tensor must be on CUDA device.")
32 if not torch.is_floating_point(x):
33 raise TypeError("silu expects a floating point tensor.")
34 if out is None:
35 out = torch.empty_like(x)
36 else:
37 if not out.is_cuda:
38 raise ValueError("Output tensor must be on CUDA device.")
39 if out.shape != x.shape:
40 raise ValueError(
41 f"Output shape {out.shape} does not match input shape {x.shape}."
42 )
43 if out.dtype != x.dtype:
44 raise TypeError(
45 f"Output dtype {out.dtype} does not match input dtype {x.dtype}."
46 )
48 x_contig = x.contiguous()
49 out_contig = out if out.is_contiguous() else torch.empty_like(x_contig)
51 n_elements = x_contig.numel()
52 if n_elements == 0:
53 if out_contig is not out:
54 out.copy_(out_contig)
55 return out
57 compute_in_fp32 = x_contig.dtype in (torch.float16, torch.bfloat16)
59 BLOCK_SIZE = 1024
60 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
61 silu_kernel[grid](
62 x_contig,
63 out_contig,
64 n_elements,
65 BLOCK_SIZE=BLOCK_SIZE,
66 COMPUTE_IN_FP32=compute_in_fp32,
67 )
69 if out_contig.data_ptr() != out.data_ptr():
70 out.copy_(out_contig)
71 return out
def silu(*args, **kwargs):
    """Functional SiLU matching the aten.silu(self) calling convention.

    The tensor may arrive positionally or under the ``self``/``input``
    keyword; anything else is rejected.
    """
    tensor = args[0] if args else kwargs.get("self", kwargs.get("input", None))
    if tensor is None:
        raise TypeError("silu expects a tensor as the first argument.")
    return _silu_impl(tensor)
def silu_out(*args, **kwargs):
    """SiLU writing into a caller-provided tensor (aten.silu.out style).

    Input and destination may be passed positionally (``x, out``) or via the
    ``self``/``input`` and ``out`` keywords; both must be present.
    """
    tensor = args[0] if len(args) >= 1 else kwargs.get("self", kwargs.get("input", None))
    dest = args[1] if len(args) >= 2 else kwargs.get("out", None)
    if tensor is None or dest is None:
        raise TypeError("silu_out expects input and out tensors.")
    _silu_impl(tensor, dest)
    return dest