Coverage for src/flag_gems/runtime/backend/_cambricon/fused/gelu_and_mul.py: 0%
69 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module-level logger, keyed by module path per the flag_gems logging convention.
logger = logging.getLogger(__name__)
# Backend-specific fast math shims; on Cambricon these map to the device's
# hardware-accelerated erf/tanh implementations exposed via tl_extra_shim.
fast_erf = tl_extra_shim.fast_erf
fast_tanh = tl_extra_shim.fast_tanh
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
    # Fused elementwise gelu(x) * y using the exact (erf-based) GELU.
    # Work in fp32 regardless of the input dtype for accuracy.
    inp = x.to(tl.float32)
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))); 0.7071067811 ~= 1/sqrt(2).
    cdf_term = 1 + fast_erf(inp * 0.7071067811)
    gated = 0.5 * inp * cdf_term
    return gated * y
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    # Backward of exact-gelu(x) * y: returns (dx, dy) with
    # dx = dgrad * y * gelu'(x) and dy = dgrad * gelu(x).
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1 / sqrt(2)
    COEFF: tl.constexpr = 0.7978845608028654  # sqrt(2/pi); 0.5*COEFF == 1/sqrt(2*pi)

    inp = x.to(tl.float32)
    # erf term is shared between the forward value and its derivative.
    erf_term = fast_erf(inp * RCP_SQRT_2)
    gelu_val = 0.5 * inp * (1 + erf_term)

    # Upstream gradient scaled by the multiplicand y.
    grad_gelu = dgrad * y
    # gelu'(x) = Phi(x) + x * phi(x), written with the common 0.5 factor
    # pulled out: 0.5 * (1 + erf(x/sqrt(2)) + x * sqrt(2/pi) * exp(-x^2/2)).
    dx = grad_gelu * 0.5 * (
        1.0 + erf_term + inp * COEFF * tl.exp(-0.5 * inp * inp)
    )
    dy = dgrad * gelu_val
    return dx, dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    # Fused elementwise gelu(x, approximate='tanh') * y.
    inp = x.to(tl.float32)
    # tanh approximation:
    #   0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    tanh_arg = inp * 0.79788456 * (1 + 0.044715 * inp * inp)
    gated = 0.5 * inp * (1 + fast_tanh(tanh_arg))
    return gated * y
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    # Backward of tanh-approximate gelu(x) * y: returns (dx, dy).
    a = x.to(tl.float32)
    b = y.to(tl.float32)

    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    cubed = a * a * a
    t = fast_tanh(sqrt_2_over_pi * (a + 0.044715 * cubed))
    gelu_a = 0.5 * a * (1 + t)

    # Gradient w.r.t. y is just the forward gelu value times the upstream grad.
    dy = gelu_a * dgrad

    # d/dx [0.5 * x * (1 + tanh(u))] = 0.5*(1 + tanh(u))
    #                                + 0.5*x*(1 - tanh^2(u)) * u'(x).
    outer_part = 0.5 * (1 + t)
    inner_part = (
        0.5
        * a
        * (1 - t * t)
        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a * a))
    )
    dx = dgrad * b * (outer_part + inner_part)
    return dx, dy
class GeluAndMul(torch.autograd.Function):
    """Autograd-aware fused gelu(x, approximate) * y.

    Forward dispatches to the exact (erf) or tanh-approximate kernel;
    backward returns gradients for both tensor inputs and None for the
    non-tensor `approximate` flag.
    """

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        logger.debug("GEMS_CAMBRICON GELU AND MUL FORWARD")
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        if approximate == "none":
            return gelu_none_and_mul_kernel(x, y)
        if approximate == "tanh":
            return gelu_tanh_and_mul_kernel(x, y)
        raise ValueError(f"Invalid approximate value: {approximate}")

    @staticmethod
    def backward(ctx, dgrad):
        logger.debug("GEMS_CAMBRICON GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        # `approximate` was validated in forward, so anything other than
        # "none" here is "tanh".
        grad_kernel = (
            gelu_none_and_mul_grad_kernel
            if ctx.approximate == "none"
            else gelu_tanh_and_mul_grad_kernel
        )
        dx, dy = grad_kernel(x, y, dgrad)
        # Third slot corresponds to the `approximate` string argument.
        return dx, dy, None
def gelu_and_mul(A, B, approximate="none"):
    """Return gelu(A, approximate=approximate) * B via the fused autograd op."""
    return GeluAndMul.apply(A, B, approximate)