Coverage for src/flag_gems/runtime/backend/_hygon/ops/silu.py: 0%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
import logging

import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic

# Module-level logger; each op wrapper below (silu, silu_backward, silu_)
# emits one debug message per call.
logger = logging.getLogger(__name__)
# SiLU (swish) forward kernel: y = x * sigmoid(x), written as x / (1 + exp(-x)).
# The arithmetic runs in float32 regardless of the input dtype. The output
# dtype is presumably chosen by pointwise_dynamic from
# promotion_methods=[(0, "DEFAULT")] (argument 0, default promotion rule);
# NOTE(review): confirm against flag_gems.utils.pointwise_dynamic.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x):
    xf = x.to(tl.float32)
    # Logistic denominator. For very negative xf, exp(-xf) overflows to inf
    # in float32, which drives the quotient to a signed zero.
    denom = 1.0 + tl.exp(-xf)
    return tl.fdiv(xf, denom)
# SiLU backward kernel. With s = sigmoid(x):
#     d/dx [x * s] = s * (1 + x * (1 - s))
# and the result is scaled by the incoming gradient dy. Both inputs are
# upcast to float32 before the arithmetic. The output dtype presumably
# follows argument 0 (x) via promotion_methods=[(0, "DEFAULT")] — confirm.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    gf = dy.to(tl.float32)
    xf = x.to(tl.float32)
    # NOTE(review): this uses tl.math.exp while silu_forward uses tl.exp;
    # on some Triton versions these lower differently, so confirm before
    # unifying them.
    s = 1.0 / (1.0 + tl.math.exp(-xf))
    # Multiply dy by s first, then by (1 + x*(1 - s)), preserving the
    # floating-point association of the formula above.
    scaled = gf * s
    return scaled * (1.0 + xf * (1.0 - s))
def silu(self):
    """Return SiLU(``self``) computed by the ``silu_forward`` Triton kernel.

    Args:
        self: Input tensor. The parameter is named ``self``, presumably to
            mirror the ATen ``silu(Tensor self)`` schema (confirm).

    Returns:
        Whatever ``silu_forward(self)`` produces. No ``out0`` is passed, so
        the result is presumably a freshly allocated tensor rather than
        ``self`` itself.
    """
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)
def silu_backward(grad_output, self):
    """Compute the gradient of SiLU with respect to its input.

    Args:
        grad_output: Upstream gradient (dL/dy).
        self: The input ``x`` that was fed to the forward SiLU.

    Returns:
        ``grad_output * d/dx[x * sigmoid(x)]``, as produced by
        ``silu_backward_kernel``.
    """
    logger.debug("GEMS SILU BACKWARD")
    # The kernel's parameter order is (x, dy), which is the reverse of this
    # function's (grad_output, self) order.
    return silu_backward_kernel(self, grad_output)
def silu_(A):
    """Apply SiLU to ``A`` in place.

    Args:
        A: Tensor that is both the input and the output destination.

    Returns:
        Whatever ``silu_forward`` returns when given ``out0=A``
        (presumably ``A`` itself).
    """
    logger.debug("GEMS SILU_ FORWARD")
    # out0=A presumably makes the pointwise_dynamic wrapper write into A's
    # storage instead of allocating a new output.
    return silu_forward(A, out0=A)