Coverage for src/flag_gems/ops/silu.py: 75%
32 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.triton_lang_extension import div_rn
# Module-scoped logger, named after this module, used by the op wrappers below.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x):
    """Elementwise SiLU, computed as ``x / (1 + exp(-x))``.

    The input is upcast to float32 and all arithmetic is done in float32.
    How the result is cast back and written out is handled by the
    ``pointwise_dynamic`` wrapper. Presumably the output follows the promoted
    dtype of argument 0 via the "DEFAULT" promotion rule — confirm against
    flag_gems.utils.
    """
    xf = x.to(tl.float32)
    denominator = 1.0 + tl.exp(-xf)
    # NOTE(review): this uses tl.fdiv with its default rounding mode, while
    # silu_backward_kernel uses div_rn. Confirm the precision difference is
    # intended.
    return tl.fdiv(xf, denominator)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    """Elementwise gradient of SiLU with respect to its input.

    With ``s = 1 / (1 + exp(-x))``, the derivative of ``x * s`` is
    ``s * (1 + x * (1 - s))``. This returns that derivative scaled by the
    upstream gradient ``dy``. Both inputs are upcast to float32 first.
    Only argument 0 (``x``) is listed in promotion_methods.
    """
    grad = dy.to(tl.float32)
    xf = x.to(tl.float32)
    # div_rn is a project helper; presumably a round-to-nearest division,
    # per its name — confirm in flag_gems.utils.triton_lang_extension.
    sig = div_rn(1.0, 1.0 + tl.exp(-xf))
    local_slope = 1.0 + xf * (1.0 - sig)
    return grad * sig * local_slope
def silu(self):
    """Return SiLU applied elementwise to ``self``.

    Delegates to the ``silu_forward`` pointwise kernel without an output
    buffer, so the kernel wrapper is responsible for producing the result
    tensor.
    """
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)
def silu_backward(grad_output, self):
    """Return the SiLU input-gradient given ``grad_output`` and the forward input ``self``."""
    logger.debug("GEMS SILU BACKWARD")
    # The kernel's parameter order is (x, dy), which is the reverse of this
    # function's (grad_output, self).
    return silu_backward_kernel(self, grad_output)
def silu_(A):
    """Apply SiLU to ``A`` in place.

    Returns whatever ``silu_forward`` returns when given ``out0=A``
    (presumably ``A`` itself — confirm against pointwise_dynamic).
    """
    logger.debug("GEMS SILU_ FORWARD")
    # Passing A as out0 directs the kernel to write its result into A.
    return silu_forward(A, out0=A)