Coverage for src/flag_gems/fused/silu_and_mul.py: 56%

39 statements  


import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def silu_and_mul_kernel(x, y):
    # silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), computed in fp32 for accuracy
    x_fp32 = x.to(tl.float32)
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y

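As a sanity check on the fused forward kernel (not part of the measured file), the output should match the pure-PyTorch reference F.silu(x) * y. A minimal sketch, assuming flag_gems is installed, a CUDA device is available, and the module is importable as flag_gems.fused.silu_and_mul (inferred from the file path, so treat it as an assumption):

import torch
import torch.nn.functional as F

from flag_gems.fused.silu_and_mul import silu_and_mul_kernel  # import path is an assumption

x = torch.randn(128, 256, device="cuda")
y = torch.randn_like(x)

ref = F.silu(x) * y                      # x * sigmoid(x) * y in eager PyTorch
out = silu_and_mul_kernel(x, y)          # fused Triton version
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
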

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def silu_and_mul_grad_kernel(x, y, dgrad):
    x_fp32 = x.to(tl.float32)
    sig = 1 / (1 + tl.exp(-x_fp32))
    x_silu = x_fp32 * sig
    # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    d_x_silu = sig * (1 + x_fp32 * (1 - sig))
    dx = d_x_silu * dgrad * y
    dy = dgrad * x_silu
    return dx, dy

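For reference, the d_x_silu expression in the backward kernel is the closed-form derivative of SiLU: with sig = sigmoid(x), d/dx [x * sigmoid(x)] = sig + x * sig * (1 - sig) = sig * (1 + x * (1 - sig)). A quick numeric check of that identity against autograd, using only plain PyTorch (an illustrative sketch, not part of the file):

import torch

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
(autograd_d,) = torch.autograd.grad((x * torch.sigmoid(x)).sum(), x)

sig = torch.sigmoid(x)
closed_form = sig * (1 + x * (1 - sig))   # same formula as d_x_silu above
torch.testing.assert_close(autograd_d, closed_form)
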

class SiluAndMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        logger.debug("GEMS SILU AND MUL FORWARD")
        return silu_and_mul_kernel(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        logger.debug("GEMS SILU AND MUL BACKWARD")
        grad_A, grad_B = silu_and_mul_grad_kernel(A, B, grad_output)
        return grad_A, grad_B


def silu_and_mul(A, B):
    return SiluAndMul.apply(A, B)

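silu_and_mul above is the differentiable entry point: SiluAndMul.apply routes the forward pass through the fused kernel and the backward pass through silu_and_mul_grad_kernel. A minimal usage sketch, assuming a CUDA device and the same hypothetical import path as above:

import torch

from flag_gems.fused.silu_and_mul import silu_and_mul  # import path is an assumption

a = torch.randn(8, 4096, device="cuda", requires_grad=True)
b = torch.randn(8, 4096, device="cuda", requires_grad=True)

out = silu_and_mul(a, b)     # forward: silu(a) * b
out.sum().backward()         # backward runs silu_and_mul_grad_kernel
print(a.grad.shape, b.grad.shape)
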

def silu_and_mul_out(A, B, out):
    logger.debug("GEMS SILU AND MUL OUT")
    # write the fused result directly into the preallocated `out` tensor
    silu_and_mul_kernel(A, B, out0=out)
    return out
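
silu_and_mul_out skips autograd entirely and writes the fused result into a caller-provided buffer via the kernel's out0 argument. A short sketch under the same assumptions (CUDA device, hypothetical import path):

import torch

from flag_gems.fused.silu_and_mul import silu_and_mul_out  # import path is an assumption

a = torch.randn(8, 4096, device="cuda")
b = torch.randn_like(a)
buf = torch.empty_like(a)

silu_and_mul_out(a, b, buf)   # buf now holds silu(a) * b; no autograd graph is recorded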