Coverage for src/flag_gems/runtime/backend/_cambricon/ops/silu.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import tl_extra_shim
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger named after this module,
# with any leading dots stripped from ``__name__`` first.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

# Module-level alias for the shim's ``div_rn``.
# NOTE(review): not referenced by the code visible in this file section.
div_rn = tl_extra_shim.div_rn
# Elementwise SiLU: x * sigmoid(x), with the arithmetic carried out in float32.
# ``inplace`` is declared as a non-tensor argument (is_tensor=[True, False]) and
# is not read inside the body.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x, inplace):
    # Upcast before evaluating the exponential.
    x_f32 = x.to(tl.float32)
    # Logistic sigmoid of the input.
    sigmoid_x = 1.0 / (1.0 + tl.exp(-x_f32))
    return sigmoid_x * x_f32
# Elementwise SiLU gradient, computed in float32:
#     dx = dy * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
# where ``x`` is the forward input and ``dy`` the incoming gradient.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    x_f32 = x.to(tl.float32)
    dy_f32 = dy.to(tl.float32)
    # Logistic sigmoid of the forward input.
    sig = 1.0 / (1.0 + tl.exp(-x_f32))
    # Bracketed factor of d/dx [x * sigmoid(x)].
    slope = 1.0 + x_f32 * (1.0 - sig)
    return dy_f32 * sig * slope
def silu(self):
    """Run the SiLU forward kernel on ``self`` and return its result.

    The kernel is invoked with its ``inplace`` argument set to ``False``.
    """
    logger.debug("GEMS_CAMBRICON SILU FORWARD")
    return silu_forward(self, False)
def silu_backward(grad_output, self):
    """Return the SiLU gradient with respect to the forward input ``self``.

    Args:
        grad_output: Gradient flowing in from the output of SiLU.
        self: The input that was given to the forward SiLU.
    """
    logger.debug("GEMS_CAMBRICON SILU BACKWARD")
    # The kernel takes (x, dy), i.e. the reverse of this function's
    # (grad_output, self) parameter order.
    return silu_backward_kernel(self, grad_output)
def silu_(A):
    """Run the SiLU forward kernel on ``A`` with ``A`` also passed as ``out0``.

    The kernel is invoked with its ``inplace`` argument set to ``True``.
    Returns whatever the kernel wrapper returns.
    """
    logger.debug("GEMS_CAMBRICON SILU_ FORWARD")
    # NOTE(review): ``out0=A`` presumably makes the pointwise wrapper write the
    # result into ``A`` itself — confirm against pointwise_dynamic.
    return silu_forward(A, True, out0=A)