Coverage for src/flag_gems/fused/gelu_and_mul.py: 50%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9erf = tl_extra_shim.erf
10pow = tl_extra_shim.pow
11tanh = tl_extra_shim.tanh
12logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
    """Elementwise GELU(x) * y using the exact (erf-based) GELU.

    Math is done in float32 regardless of input dtype; the output dtype
    follows the decorator's "DEFAULT" promotion of inputs 0 and 1.
    """
    x_fp32 = x.to(tl.float32)
    # 1/sqrt(2), truncated to 10 digits (sufficient for float32 math).
    RCP_SQRT_2: tl.constexpr = 0.7071067811
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
    return x_gelu * y
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    """Backward of exact-GELU(x) * y: returns (dL/dx, dL/dy).

    With f = GELU(x) * y:
      dL/dy = dgrad * GELU(x)
      dL/dx = dgrad * y * GELU'(x),  GELU'(x) = Phi(x) + x * phi(x)
    where Phi is the standard-normal CDF and phi its PDF.
    """
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1/sqrt(2)
    COEFF: tl.constexpr = 0.7978845608028654  # sqrt(2/pi)

    x_fp32 = x.to(tl.float32)
    # Recompute forward GELU for dL/dy (nothing is saved between passes).
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))

    d_gelu = dgrad * y
    # 0.5 * (1 + erf(x/sqrt(2)))            == Phi(x)
    # 0.5 * COEFF * exp(-x^2/2) = 1/sqrt(2*pi) * exp(-x^2/2) == phi(x)
    # so the bracket below is Phi(x) + x*phi(x) = GELU'(x).
    dx = (
        d_gelu
        * 0.5
        * (
            1.0
            + erf(x_fp32 * RCP_SQRT_2)
            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
        )
    )

    dy = dgrad * x_gelu

    return dx, dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    """Elementwise GELU(x) * y using the tanh approximation of GELU.

    tanh-GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

    Math is done in float32 regardless of input dtype; the output dtype
    follows the decorator's "DEFAULT" promotion of inputs 0 and 1.
    """
    x_fp32 = x.to(tl.float32)
    # x * 0.79788456 * (1 + 0.044715*x^2) == sqrt(2/pi) * (x + 0.044715*x^3).
    # x_fp32 is already float32 (no redundant cast), and squaring by
    # multiplication matches how the grad kernel forms the cube.
    tanh_arg = x_fp32 * 0.79788456 * (1 + 0.044715 * x_fp32 * x_fp32)
    x_gelu = 0.5 * x_fp32 * (1 + tanh(tanh_arg))
    return x_gelu * y
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    """Backward of tanh-approx-GELU(x) * y: returns (dL/dx, dL/dy).

    Let u = sqrt(2/pi) * (x + 0.044715*x^3) and t = tanh(u), so
    GELU(x) ~= 0.5 * x * (1 + t).  Product rule gives
      GELU'(x) = 0.5*(1 + t) + 0.5*x*(1 - t^2)*u'
    with u' = sqrt(2/pi) * (1 + 3*0.044715*x^2).
    """
    x_fp32 = x.to(tl.float32)
    y_fp32 = y.to(tl.float32)

    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = x_fp32 * x_fp32 * x_fp32
    tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)  # u
    tanh_result = tanh(tanh_arg)  # t
    # Recompute forward GELU for dL/dy (nothing is saved between passes).
    geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
    dy = geglu_a * dgrad

    term1 = 0.5 * (1 + tanh_result)  # d/dx of the 0.5*x factor's share
    tanh_sq = tanh_result * tanh_result
    # 0.5*x * sech^2(u) * u'  (sech^2 = 1 - tanh^2)
    term2 = (
        0.5
        * x_fp32
        * (1 - tanh_sq)
        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
    )
    dx = dgrad * y_fp32 * (term1 + term2)

    return dx, dy
class GeluAndMul(torch.autograd.Function):
    """Autograd wrapper for the fused GELU(x) * y operation.

    Supports the same ``approximate`` modes as ``torch.nn.GELU``:
    ``"none"`` (exact, erf-based) and ``"tanh"``.
    """

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        logger.debug("GEMS GELU AND MUL FORWARD")
        # Inputs are needed again in backward to recompute GELU and its
        # derivative; the mode string selects the matching grad kernel.
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        if approximate == "none":
            return gelu_none_and_mul_kernel(x, y)
        if approximate == "tanh":
            return gelu_tanh_and_mul_kernel(x, y)
        raise ValueError(f"Invalid approximate value: {approximate}")

    @staticmethod
    def backward(ctx, dgrad):
        logger.debug("GEMS GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        # forward() already rejected anything other than "none"/"tanh",
        # so a two-way dispatch is sufficient here.
        grad_kernel = (
            gelu_none_and_mul_grad_kernel
            if ctx.approximate == "none"
            else gelu_tanh_and_mul_grad_kernel
        )
        dx, dy = grad_kernel(x, y, dgrad)
        # Third gradient slot corresponds to the non-tensor `approximate` arg.
        return dx, dy, None
def gelu_and_mul(x, y, approximate="none"):
    """Return GELU(x) * y, fused into a single kernel launch.

    `approximate` is "none" for the exact erf-based GELU or "tanh" for
    the tanh approximation; any other value raises ValueError.
    """
    result = GeluAndMul.apply(x, y, approximate)
    return result