Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/relu.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Module-level logger, created as a child of the shared "flag_gems" logger.
# Any leading dots are stripped from the module name before it is used as
# the child suffix.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

9 

10 

# Element-wise ReLU forward, expressed as maximum(x, 0).
# tl.maximum(x, 0) is preferred over a tl.where formulation: per the original
# author's note, maximum maps to a single max instruction, whereas tl.where
# needs two instructions (a compare followed by a select).
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def relu_forward(x):
    clamped = tl.maximum(x, 0)
    return clamped

17 

18 

# Element-wise ReLU backward: pass the incoming gradient `dy` through where
# the forward input `x` was strictly positive, and emit 0 everywhere else.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def relu_backward(x, dy):
    is_positive = x > 0
    return tl.where(is_positive, dy, 0)

23 

24 

def relu(self):
    """Apply ReLU to ``self`` and return the result of ``relu_forward``.

    No explicit output buffer is passed, so the result is whatever the
    ``relu_forward`` wrapper produces for its single input.
    """
    logger.debug("GEMS RELU FORWARD")
    return relu_forward(self)

29 

30 

def relu_(A):
    """Apply ReLU to ``A`` using ``A`` itself as the output buffer.

    ``A`` is passed as ``out0`` to ``relu_forward``; the value returned by
    that call is handed back to the caller unchanged (presumably ``A``
    itself -- confirm against ``pointwise_dynamic``).
    """
    logger.debug("GEMS RELU_ FORWARD")
    return relu_forward(A, out0=A)