Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/silu.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import triton
4import triton.language as tl
5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, nested under the "flag_gems" logger hierarchy; the leading
# dots of __name__ are stripped before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Division helper from the backend shim, used by silu_backward_kernel.
# NOTE(review): the "_rn" suffix suggests round-to-nearest division — confirm
# against tl_extra_shim.
div_rn = tl_extra_shim.div_rn

# Code-generation tuning config handed to pointwise_dynamic for silu_forward
# only; silu_backward_kernel is decorated without a config.
# NOTE(review): the meaning of the four positional values (512, the 3-tuple of
# 65536s, 32, True) is defined by CodeGenConfig and is not visible in this
# file — TODO document once confirmed against codegen_config_utils.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    buffer_size_limit=4096,
    isCloseVectorization=True,
    unroll_num=8,
)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def silu_forward(x):
    # Elementwise SiLU: x * sigmoid(x), written as x / (1 + exp(-x)).
    # The input is upcast to fp32 and the result is returned in fp32.
    # NOTE(review): any cast back to the promoted output dtype is presumably
    # handled by pointwise_dynamic — confirm.
    xf = x.to(tl.float32)
    denom = 1.0 + tl.exp(-xf)
    # tl.fdiv here (the backward kernel uses div_rn); kept as in the original.
    return tl.fdiv(xf, denom)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    # Elementwise SiLU gradient, computed in fp32:
    #   d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))),  s = sigmoid.
    # Argument order is (x, dy): forward input first, upstream gradient second.
    xf = x.to(tl.float32)
    grad = dy.to(tl.float32)
    s = div_rn(1.0, 1.0 + tl.exp(-xf))
    # Same left-to-right product as before: (grad * s) * (1 + xf * (1 - s)).
    return grad * s * (1.0 + xf * (1.0 - s))
def silu(self):
    """Return SiLU(self) as a new tensor computed by the silu_forward kernel."""
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)
def silu_backward(grad_output, self):
    """Return the SiLU input gradient for upstream grad_output at input self.

    The kernel expects (input, upstream_grad), so the two arguments are
    passed in swapped order relative to this function's signature.
    """
    logger.debug("GEMS SILU BACKWARD")
    return silu_backward_kernel(self, grad_output)
def silu_(A):
    """Apply SiLU to A in place and return what silu_forward returns.

    The in-place write is done by passing A as the kernel's output
    (``out0=A``).
    """
    logger.debug("GEMS SILU_ FORWARD")
    return silu_forward(A, out0=A)