Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sigmoid.py: 0%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import triton
4import triton.language as tl
5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the shared "flag_gems" logger, named after this
# module's __name__ with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Module-level alias for the exp2 implementation exposed by tl_extra_shim.
exp2 = tl_extra_shim.exp2
# Code-generation settings passed to pointwise_dynamic for sigmoid_forward.
# The first four arguments are positional; their meaning is fixed by the
# CodeGenConfig signature, which is not visible in this file -- TODO confirm
# against _kunlunxin.utils.codegen_config_utils before changing any value.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,  # same value as (65536, 65536, 65536)
    32,
    True,
    prefer_1d_tile=True, buffer_size_limit=4096,
    isCloseVectorization=True, kunlunAutoGrid=True,
    unroll_num=8,
)
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")], config=config_)
@triton.jit
def sigmoid_forward(x):
    # Element-wise logistic function 1 / (1 + e^(-x)), evaluated in float32.
    #
    # NOTE: an exp2-based formulation would need a log2(e) constant. triton 3.0.0
    # disallows calling a non-jitted function (e.g. math.log2) inside a jitted
    # function, even on the right-hand side of a constexpr assignment, so that
    # constant would have to be written as the numeric literal
    # 1.4426950408889634. This kernel uses tl.exp directly instead.
    x_f32 = x.to(tl.float32)
    denom = 1 + tl.exp(-x_f32)
    return 1 / denom
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sigmoid_backward_kernel(dy, y):
    # Element-wise dy * (1 - y) * y, with both operands cast to float32 first.
    # sigmoid_backward passes the forward result as `y`, in which case this is
    # the upstream gradient times the sigmoid derivative y * (1 - y).
    grad_f32 = dy.to(tl.float32)
    sig_f32 = y.to(tl.float32)
    one_minus_sig = 1.0 - sig_f32
    return grad_f32 * one_minus_sig * sig_f32
def sigmoid(self):
    """Apply ``sigmoid_forward`` to ``self`` and return the result.

    Emits a debug-level log record before launching the kernel.
    """
    logger.debug("GEMS SIGMOID FORWARD")
    return sigmoid_forward(self)
def sigmoid_backward(grad_output, output):
    """Return ``sigmoid_backward_kernel(grad_output, output)``.

    ``grad_output`` is forwarded as the kernel's ``dy`` and ``output`` as its
    ``y``. Emits a debug-level log record before launching the kernel.
    """
    logger.debug("GEMS SIGMOID BACKWARD")
    return sigmoid_backward_kernel(grad_output, output)
def sigmoid_(A):
    """Run ``sigmoid_forward`` on ``A`` with ``A`` also supplied as ``out0``.

    Returns whatever ``sigmoid_forward`` returns. Emits a debug-level log
    record before launching the kernel.
    """
    # NOTE(review): passing A as out0 presumably makes the kernel write its
    # result back into A (in-place variant) -- the out0 contract is defined by
    # pointwise_dynamic, which is not visible here; confirm there.
    logger.debug("GEMS SIGMOID_ FORWARD")
    return sigmoid_forward(A, out0=A)