Coverage for src/flag_gems/runtime/backend/_hygon/ops/silu.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Module-level logger; each SiLU entry point in this file emits one debug line per call.
logger = logging.getLogger(__name__)

9 

10 

# Elementwise SiLU ("swish"): y = x * sigmoid(x), evaluated as x / (1 + e^(-x)).
# All arithmetic happens in float32. The dtype the result is stored in is
# presumably decided by pointwise_dynamic's DEFAULT promotion of argument 0
# (TODO confirm against flag_gems.utils.pointwise_dynamic).
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x):
    # Upcast before exp/divide so the computation is done in float32.
    xf = x.to(tl.float32)
    denom = 1.0 + tl.exp(-xf)
    return tl.fdiv(xf, denom)

17 

18 

# Input gradient of SiLU, scaled by the upstream gradient dy:
#   d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))),   where s = sigmoid.
# Both operands are upcast to float32 before any arithmetic. Output dtype
# promotion is presumably keyed on argument 0 (x) -- TODO confirm.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    xf = x.to(tl.float32)
    gf = dy.to(tl.float32)
    # Sigmoid of the forward input; tl.math.exp is kept as in the original
    # rather than switched to tl.exp.
    sig = 1.0 / (1.0 + tl.math.exp(-xf))
    return gf * sig * (1.0 + xf * (1.0 - sig))

27 

28 

def silu(self):
    """Apply SiLU elementwise to ``self`` and return the kernel output.

    No ``out0`` is passed to ``silu_forward``, so the result is presumably a
    freshly allocated tensor rather than ``self`` itself.
    """
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)

33 

34 

def silu_backward(grad_output, self):
    """Return the gradient of SiLU with respect to its forward input.

    Args:
        grad_output: upstream gradient (``dy``).
        self: the input that was passed to the forward SiLU (``x``).
    """
    logger.debug("GEMS SILU BACKWARD")
    # The kernel's parameter order is (x, dy), the reverse of this signature.
    return silu_backward_kernel(self, grad_output)

39 

40 

def silu_(A):
    """In-place SiLU on ``A``.

    ``A`` is passed as ``out0`` to ``silu_forward``, presumably making it the
    kernel's output buffer. The kernel's return value is handed back to the
    caller.
    """
    logger.debug("GEMS SILU_ FORWARD")
    return silu_forward(A, out0=A)