Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/relu.py: 0%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import triton
4import triton.language as tl
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module's __name__ (the lstrip is defensive; __name__ never begins with ".").
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Element-wise ReLU forward: max(x, 0).
# A single tl.maximum maps to one max instruction, whereas the equivalent
# tl.where(x > 0, x, 0) needs a compare followed by a select, so maximum is
# used here.
# NOTE(review): tl.maximum's default propagate_nan is NONE, so the result for a
# NaN input is backend-defined; torch.relu keeps NaN -- confirm on Kunlunxin.
# Comments are kept outside the kernel body on purpose (no docstring), since
# pointwise_dynamic may embed the scalar function's source in generated code.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def relu_forward(x):
    return tl.maximum(x, 0)
# Element-wise ReLU gradient: forward the upstream gradient `dy` where `x` is
# strictly positive, and emit 0 everywhere else. `x > 0` is False for NaN, so
# NaN positions of `x` also get 0.
# Output dtype follows argument 0 (`x`) per the DEFAULT promotion rule below.
# NOTE(review): whether callers pass the forward input or the forward output as
# `x` is not visible here -- confirm at the call site.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def relu_backward(x, dy):
    positive = x > 0
    return tl.where(positive, dy, 0)
def relu(self):
    """Compute ReLU of ``self`` out of place.

    Logs the call, then dispatches to the ``relu_forward`` pointwise kernel
    (element-wise ``max(x, 0)``). No ``out0`` is passed, so the output buffer
    is presumably allocated by ``pointwise_dynamic``; its dtype follows that
    decorator's DEFAULT promotion of the input.
    """
    logger.debug("GEMS RELU FORWARD")
    return relu_forward(self)
def relu_(A):
    """Apply ReLU to ``A`` in place.

    ``A`` is handed to ``relu_forward`` as ``out0``, so the kernel writes its
    result back into the input. Whatever ``relu_forward`` returns is passed
    through unchanged (presumably ``A`` itself -- confirm against
    ``pointwise_dynamic``).
    """
    logger.debug("GEMS RELU_ FORWARD")
    return relu_forward(A, out0=A)