Coverage for src/flag_gems/ops/silu.py: 75%

32 statements

Generated by coverage.py v7.6.9 at 2026-03-21 14:31 +0800.

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.triton_lang_extension import div_rn 

8 

# Module-level logger; the public SiLU entry points in this module emit a
# debug line through it each time they are invoked.
logger: logging.Logger = logging.getLogger(__name__)

10 

11 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x):
    """Element-wise SiLU: ``x * sigmoid(x) == x / (1 + exp(-x))``.

    The input is upcast to float32 before any arithmetic, so ``exp`` and the
    division run in float32 regardless of the input dtype. Casting the
    float32 result to the output dtype is presumably handled by the wrapper
    that ``pointwise_dynamic`` generates (output dtype derived from argument
    0 with "DEFAULT" promotion) — confirm against that decorator.

    Args:
        x: input values, as supplied by the ``pointwise_dynamic``-generated
            kernel.

    Returns:
        The SiLU activation of ``x``, computed in float32.
    """
    x_fp32 = x.to(tl.float32)
    # Use the round-to-nearest division helper (the same one the backward
    # kernel uses for its sigmoid) rather than tl.fdiv. tl.fdiv defaults to
    # ieee_rounding=False, i.e. an approximate division. This keeps forward
    # and backward numerically consistent and IEEE-rounded.
    y = div_rn(x_fp32, 1.0 + tl.exp(-x_fp32))
    return y

18 

19 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    """Element-wise gradient of SiLU with respect to its input.

    With ``s = sigmoid(x)``, ``d/dx [x * s] = s * (1 + x * (1 - s))``, so the
    value returned is ``dy * s * (1 + x * (1 - s))``. Both operands are
    upcast to float32 first, and the sigmoid is formed with the
    round-to-nearest ``div_rn`` helper.

    Args:
        x: the input that was fed to the SiLU forward pass.
        dy: the incoming gradient with respect to SiLU's output.

    Returns:
        The gradient with respect to ``x``, computed in float32. The output
        dtype is presumably derived from ``x`` only, since promotion_methods
        lists argument index 0 alone — confirm.
    """
    grad = dy.to(tl.float32)
    inp = x.to(tl.float32)
    sig = div_rn(1.0, 1.0 + tl.exp(-inp))
    # Keep the original association, (dy * s) * (1 + x * (1 - s)), so that
    # rounding is unchanged.
    scaled = grad * sig
    bracket = 1.0 + inp * (1.0 - sig)
    return scaled * bracket

28 

29 

def silu(self):
    """Apply SiLU element-wise to ``self`` via the ``silu_forward`` Triton kernel.

    Args:
        self: input tensor. The ``self`` name mirrors the ATen operator
            schema; presumably this function is registered as the
            ``aten::silu`` override — confirm at the registration site.

    Returns:
        The tensor produced by ``silu_forward``: ``self / (1 + exp(-self))``.
        It is presumably newly allocated, since no ``out0`` is passed.
    """
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)

34 

35 

def silu_backward(grad_output, self):
    """Compute the gradient of SiLU with respect to its input.

    Args:
        grad_output: gradient with respect to SiLU's output.
        self: the tensor that was passed to the SiLU forward pass.

    Returns:
        ``grad_output * s * (1 + self * (1 - s))`` with ``s = sigmoid(self)``,
        as computed by ``silu_backward_kernel``.
    """
    logger.debug("GEMS SILU BACKWARD")
    # The kernel's parameter order is (x, dy), the reverse of this
    # function's (grad_output, self) order.
    return silu_backward_kernel(self, grad_output)

40 

41 

def silu_(A):
    """In-place SiLU: overwrite ``A`` with ``A / (1 + exp(-A))``.

    ``A`` is passed both as the input and as ``out0``, so the kernel writes
    its result back into ``A``'s storage.

    Args:
        A: tensor to transform in place.

    Returns:
        Whatever ``silu_forward`` returns for the ``out0=A`` call. This is
        presumably ``A`` itself — confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS SILU_ FORWARD")
    return silu_forward(A, out0=A)