Coverage for src/flag_gems/fused/silu_and_mul.py: 59% (37 statements)

import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def silu_and_mul_kernel(x, y):
    # Compute silu(x) * y in float32 for numerical stability:
    # silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
    x_fp32 = x.to(tl.float32)
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y
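

# A minimal eager-mode reference for the fused kernel above: an illustrative
# sketch, not part of the original module, and the helper name is hypothetical.
# Output dtype may differ from the Triton path's promotion rules.
def _silu_and_mul_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same math as silu_and_mul_kernel, in plain PyTorch.
    return torch.nn.functional.silu(x.float()) * y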


@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def silu_and_mul_grad_kernel(x, y, dgrad):
    x_fp32 = x.to(tl.float32)
    sig = 1 / (1 + tl.exp(-x_fp32))
    x_silu = x_fp32 * sig
    # Product rule: d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
    d_x_silu = sig * (1 + x_fp32 * (1 - sig))
    dx = d_x_silu * dgrad * y
    dy = dgrad * x_silu
    return dx, dy
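

# A hedged sketch of the same backward in eager PyTorch, useful for
# spot-checking the Triton kernel. The helper name is hypothetical and not
# part of the original file.
def _silu_and_mul_grad_torch(x, y, dgrad):
    sig = torch.sigmoid(x.float())
    d_x_silu = sig * (1 + x.float() * (1 - sig))
    return d_x_silu * dgrad * y, dgrad * (x.float() * sig)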


class SiluAndMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        logger.debug("GEMS SILU AND MUL FORWARD")
        return silu_and_mul_kernel(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A, grad_B = silu_and_mul_grad_kernel(A, B, grad_output)
        return grad_A, grad_B


def silu_and_mul(A, B):
    return SiluAndMul.apply(A, B)
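

# Example usage (an illustrative sketch; the device, shapes, and dtypes are
# assumptions, not taken from the original file):
#     x = torch.randn(8, 128, device="cuda", requires_grad=True)
#     y = torch.randn(8, 128, device="cuda", requires_grad=True)
#     out = silu_and_mul(x, y)
#     out.sum().backward()  # routes through silu_and_mul_grad_kernel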


def silu_and_mul_out(A, B, out):
    # Preallocated-output variant: pointwise_dynamic accepts the output
    # tensor via the out0 keyword.
    silu_and_mul_kernel(A, B, out0=out)
    return out
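

# Example (sketch; assumes `out` already has the promoted result dtype and a
# shape broadcast-compatible with A and B):
#     out = torch.empty_like(A)
#     silu_and_mul_out(A, B, out)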