Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/silu_and_mul.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

import logging

import torch
import triton
import triton.language as tl

from ..utils.pointwise_dynamic import pointwise_dynamic

# Child logger under the package-wide "flag_gems" logger; lstrip(".") keeps
# getChild happy if __name__ is a relative-looking module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

10 

11 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def silu_and_mul_kernel(x, y):
    """Fused elementwise silu(x) * y.

    silu(x) = x / (1 + exp(-x)); the silu part is evaluated in float32
    before the final multiply with y. Output dtype follows the DEFAULT
    promotion of (x, y) via pointwise_dynamic.
    """
    x_f32 = x.to(tl.float32)
    denom = 1.0 + tl.exp(-x_f32)
    silu_x = tl.fdiv(x_f32, denom)
    return silu_x * y

18 

19 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def silu_and_mul_grad_kernel(x, y, dgrad):
    """Backward of out = silu(x) * y.

    Given the upstream gradient ``dgrad`` (d loss / d out), returns
        dx = silu'(x) * dgrad * y
        dy = silu(x) * dgrad
    with the silu terms evaluated in float32.
    """
    xf = x.to(tl.float32)
    sigmoid_x = 1 / (1 + tl.exp(-xf))
    silu_x = xf * sigmoid_x
    # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    silu_grad = sigmoid_x * (1 + xf * (1 - sigmoid_x))
    dy = dgrad * silu_x
    dx = silu_grad * dgrad * y
    return dx, dy

32 

33 

class SiluAndMul(torch.autograd.Function):
    """Autograd-aware fused silu-and-mul.

    forward returns silu(A) * B via the fused Triton kernel; backward
    routes grad_output through silu_and_mul_grad_kernel to produce
    gradients for both inputs.
    """

    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        logger.debug("GEMS SILU AND MUL FORWARD")
        return silu_and_mul_kernel(A, B)

    # Fix: torch.autograd.Function requires backward to be a @staticmethod
    # (like forward); the original defined it as an instance method, which
    # violates the documented Function API.
    @staticmethod
    def backward(ctx, grad_output):
        logger.debug("GEMS SILU AND MUL BACKWARD")
        A, B = ctx.saved_tensors
        grad_A, grad_B = silu_and_mul_grad_kernel(A, B, grad_output)
        return grad_A, grad_B

45 

46 

def silu_and_mul(A, B):
    """Return silu(A) * B as a new tensor, with autograd support."""
    result = SiluAndMul.apply(A, B)
    return result

49 

50 

def silu_and_mul_out(A, B, out):
    """Compute silu(A) * B into the preallocated tensor ``out``.

    Bypasses autograd (calls the forward kernel directly) and returns
    ``out`` for convenience.
    """
    silu_and_mul_kernel(A, B, out0=out)
    return out