Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/silu.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# Module logger, created as a child of the package-wide "flag_gems" logger.
# Spelling the dotted name out is equivalent to
# logging.getLogger("flag_gems").getChild(...): both go through the same
# logger manager and yield the same logger object.
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))

# Division helper looked up once from tl_extra_shim so the backward kernel can
# call it by a short name. Presumably a round-to-nearest ("rn") division
# provided per backend. NOTE(review): confirm semantics in tl_extra_shim.
div_rn = tl_extra_shim.div_rn

13 

# Code-generation settings passed to pointwise_dynamic for silu_forward only
# (the backward kernel below does not receive them).
# The four positional arguments are order-sensitive. They presumably are:
# max tile size, max grid size per axis (3 axes), max warps per CTA, and a
# block-pointer preference flag.
# NOTE(review): confirm against CodeGenConfig's signature.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,  # identical limit on each of the three grid axes
    32,
    True,
    prefer_1d_tile=True,
    buffer_size_limit=4096,
    isCloseVectorization=True,
    unroll_num=8,
)

24 

25 

# Elementwise SiLU forward: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
# The math is done in float32 regardless of the input dtype.
# pointwise_dynamic wraps this scalar function into a tensor op. It is
# configured with "DEFAULT" promotion on argument 0, which presumably decides
# the output dtype, and with the kunlunxin-specific config_ settings.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def silu_forward(x):
    xf = x.to(tl.float32)
    denom = 1.0 + tl.exp(-xf)
    # tl.fdiv (default ieee_rounding=False per the Triton docs) is used instead
    # of "/", trading exact IEEE rounding for a faster division.
    return tl.fdiv(xf, denom)

32 

33 

# Elementwise SiLU backward: dx = dy * d/dx[x * s(x)], where s = sigmoid.
#   d/dx[x * s] = s + x * s * (1 - s) = s * (1 + x * (1 - s))
# Inputs are upcast to float32 before any arithmetic.
# No config= is given, so pointwise_dynamic's own default code-generation
# settings apply here (unlike silu_forward).
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward_kernel(x, dy):
    xf = x.to(tl.float32)
    gf = dy.to(tl.float32)
    # sigmoid(x) via the div_rn helper rather than tl.fdiv.
    sig = div_rn(1.0, 1.0 + tl.exp(-xf))
    slope = 1.0 + xf * (1.0 - sig)
    # Same association as (dy * s) * (1 + x * (1 - s)).
    return gf * sig * slope

42 

43 

def silu(self):
    """Apply SiLU elementwise to ``self`` and return the result.

    Logs a debug trace, then delegates to the generated ``silu_forward`` op,
    which allocates and returns a new output.
    """
    logger.debug("GEMS SILU FORWARD")
    return silu_forward(self)

48 

49 

def silu_backward(grad_output, self):
    """Return the gradient of SiLU with respect to its input.

    Args:
        grad_output: upstream gradient ``dy``.
        self: the tensor that was fed to the forward SiLU (``x``).

    Note that the kernel's parameter order is ``(x, dy)``, the reverse of
    this function's ``(grad_output, self)``.
    """
    logger.debug("GEMS SILU BACKWARD")
    return silu_backward_kernel(self, grad_output)

54 

55 

def silu_(A):
    """In-place SiLU on ``A``.

    ``out0=A`` directs ``silu_forward``'s output buffer to ``A`` itself.
    Whatever ``silu_forward`` returns for that call is passed back to the
    caller.
    """
    logger.debug("GEMS SILU_ FORWARD")
    return silu_forward(A, out0=A)