Coverage for src/flag_gems/fused/silu_and_mul_with_clamp.py: 53%
53 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
9logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def silu_and_mul_with_clamp_kernel(x, y, limit):
    # Elementwise SiLU(min(x, limit)) * clamp(y, -limit, limit).
    # All arithmetic below is carried out on float32 copies of the inputs.
    bound = limit.to(tl.float32)
    gate_in = x.to(tl.float32)
    up_in = y.to(tl.float32)

    # The gate branch is capped from above only; the up branch is clamped
    # on both sides to [-limit, limit].
    gate_val = tl.minimum(gate_in, bound)
    up_floored = tl.maximum(up_in, -bound)
    up_val = tl.minimum(up_floored, bound)

    # SiLU(g) = g / (1 + exp(-g))
    denom = 1.0 + tl.exp(-gate_val)
    activated = tl.fdiv(gate_val, denom)

    return activated * up_val
@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, 3, "DEFAULT"),
        (0, 1, 2, 3, "DEFAULT"),
    ],
    num_outputs=2,
)
@triton.jit
def silu_and_mul_with_clamp_grad_kernel(x, y, dgrad, limit):
    # Backward of out = SiLU(min(x, limit)) * clamp(y, -limit, limit).
    # Returns the pair (dx, dy); all arithmetic is done on float32 copies.
    bound = limit.to(tl.float32)
    grad_out = dgrad.to(tl.float32)
    gate_in = x.to(tl.float32)
    up_in = y.to(tl.float32)

    # Recompute the clamped forward operands.
    gate_val = tl.minimum(gate_in, bound)
    up_val = tl.minimum(tl.maximum(up_in, -bound), bound)

    # sigmoid(g), SiLU(g) = g * sigmoid(g), and its derivative
    # d/dg SiLU(g) = sigmoid(g) * (1 + g * (1 - sigmoid(g))).
    sig_val = 1 / (1 + tl.exp(-gate_val))
    silu_val = gate_val * sig_val
    silu_slope = sig_val * (1 + gate_val * (1 - sig_val))

    # 1.0 where the input lies inside the clamp range (boundaries included),
    # 0.0 where the clamp was active and therefore blocks the gradient.
    gate_open = (gate_in <= bound).to(tl.float32)
    up_open = ((up_in >= -bound) & (up_in <= bound)).to(tl.float32)

    dx = grad_out * up_val * silu_slope * gate_open
    dy = grad_out * silu_val * up_open

    return dx, dy
class SiluAndMulWithClamp(torch.autograd.Function):
    """Autograd function pairing the clamped SiLU-and-mul forward and grad kernels."""

    @staticmethod
    def forward(ctx, x, y, limit):
        """Run the forward kernel and stash ``x``, ``y`` and the limit for backward.

        ``limit`` is wrapped in a tensor on ``x``'s device with ``x``'s dtype
        before being handed to the kernel.
        """
        bound = torch.tensor(limit, dtype=x.dtype, device=x.device)
        ctx.save_for_backward(x, y, bound)
        logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_FORWARD")
        result = silu_and_mul_with_clamp_kernel(x, y, bound)
        return result

    @staticmethod
    def backward(ctx, dgrad):
        """Return gradients for ``(x, y, limit)``; ``limit`` gets no gradient."""
        saved = ctx.saved_tensors
        x, y, bound = saved
        logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_BACKWARD")
        grad_x, grad_y = silu_and_mul_with_clamp_grad_kernel(x, y, dgrad, bound)
        # One entry per forward input; the third slot corresponds to ``limit``.
        return grad_x, grad_y, None
def silu_and_mul_with_clamp(x, y, limit):
    """Apply ``SiluAndMulWithClamp`` to ``x`` and ``y`` with the given ``limit``.

    Delegates to ``SiluAndMulWithClamp.apply`` and returns its result.
    """
    result = SiluAndMulWithClamp.apply(x, y, limit)
    return result
def silu_and_mul_with_clamp_out(x, y, out, limit):
    """Run the forward kernel with ``out`` supplied as ``out0`` and return ``out``.

    This path calls the kernel directly rather than going through
    ``SiluAndMulWithClamp``. ``limit`` is wrapped in a tensor on ``x``'s
    device with ``x``'s dtype first.
    """
    logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_OUT")
    bound = torch.tensor(limit, dtype=x.dtype, device=x.device)
    # NOTE(review): ``out0`` is presumably the kernel wrapper's keyword for a
    # caller-provided output buffer — confirm against pointwise_dynamic.
    silu_and_mul_with_clamp_kernel(x, y, bound, out0=out)
    return out