Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/gelu_and_mul.py: 0%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
12erf = tl_extra_shim.erf
13pow = tl_extra_shim.pow
14tanh = tl_extra_shim.tanh
17@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
18@triton.jit
19def gelu_none_and_mul_kernel(x, y):
20 x_fp32 = x.to(tl.float32)
21 x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * 0.7071067811))
22 return x_gelu * y
25@pointwise_dynamic(
26 promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
27)
28@triton.jit
29def gelu_none_and_mul_grad_kernel(x, y, dgrad):
30 RCP_SQRT_2: tl.constexpr = 0.7071067811
31 COEFF: tl.constexpr = 0.7978845608028654
33 x_fp32 = x.to(tl.float32)
34 x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
36 d_gelu = dgrad * y
37 dx = (
38 d_gelu
39 * 0.5
40 * (
41 1.0
42 + erf(x_fp32 * RCP_SQRT_2)
43 + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
44 )
45 )
47 dy = dgrad * x_gelu
49 return dx, dy
52@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
53@triton.jit
54def gelu_tanh_and_mul_kernel(x, y):
55 x_fp32 = x.to(tl.float32)
56 x_gelu = (
57 0.5
58 * x_fp32
59 * (
60 1
61 + tanh(
62 x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32.to(tl.float32), 2.0))
63 )
64 )
65 )
66 return x_gelu * y
69@pointwise_dynamic(
70 promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
71)
72@triton.jit
73def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
74 x_fp32 = x.to(tl.float32)
75 y_fp32 = y.to(tl.float32)
77 sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
78 a_cubed = x_fp32 * x_fp32 * x_fp32
79 tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
80 tanh_result = tanh(tanh_arg)
81 geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
82 dy = geglu_a * dgrad
84 term1 = 0.5 * (1 + tanh_result)
85 tanh_sq = tanh_result * tanh_result
86 term2 = (
87 0.5
88 * x_fp32
89 * (1 - tanh_sq)
90 * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
91 )
92 dx = dgrad * y_fp32 * (term1 + term2)
94 return dx, dy
97class GeluAndMul(torch.autograd.Function):
98 @staticmethod
99 def forward(ctx, x, y, approximate="none"):
100 logger.debug("GEMS GELU AND MUL FORWARD")
101 ctx.save_for_backward(x, y)
102 ctx.approximate = approximate
103 if approximate == "none":
104 return gelu_none_and_mul_kernel(x, y)
105 elif approximate == "tanh":
106 return gelu_tanh_and_mul_kernel(x, y)
107 else:
108 raise ValueError(f"Invalid approximate value: {approximate}")
110 @staticmethod
111 def backward(ctx, dgrad):
112 logger.debug("GEMS GELU AND MUL BACKWARD")
113 x, y = ctx.saved_tensors
114 if ctx.approximate == "none":
115 dx, dy = gelu_none_and_mul_grad_kernel(x, y, dgrad)
116 else:
117 dx, dy = gelu_tanh_and_mul_grad_kernel(x, y, dgrad)
118 return dx, dy, None
121def gelu_and_mul(x, y, approximate="none"):
122 return GeluAndMul.apply(x, y, approximate)