Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gelu.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import tl_extra_shim
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, obtained as a child of the shared "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Short module-level aliases for the math intrinsics exposed by tl_extra_shim;
# the Triton kernels in this module call them by these names.
exp = tl_extra_shim.exp
fast_erf = tl_extra_shim.fast_erf
fast_tanh = tl_extra_shim.fast_tanh
# Forward GELU, erf formulation: 0.5 * x * (1 + erf(x / sqrt(2))).
# `inplace` is accepted as a non-tensor argument (see is_tensor) but is not
# read by the kernel body.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x, inplace):
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / sqrt(2)
    xf = x.to(tl.float32)  # evaluate in float32
    half_x = 0.5 * xf
    # Written as 0.5*x + 0.5*x*erf(...) rather than the factored form.
    return half_x + half_x * fast_erf(xf * inv_sqrt2)
# Forward GELU, tanh approximation:
#   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
# `inplace` is accepted as a non-tensor argument (see is_tensor) but is not
# read by the kernel body.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x, inplace):
    sqrt_2_over_pi: tl.constexpr = 0.79788456
    cubic_coeff: tl.constexpr = 0.044715
    xf = x.to(tl.float32)  # evaluate in float32
    scaled = xf * sqrt_2_over_pi
    # inner = sqrt(2/pi) * x + sqrt(2/pi) * 0.044715 * x^3
    inner = scaled + scaled * cubic_coeff * xf * xf
    half_x = 0.5 * xf
    return half_x + half_x * fast_tanh(inner)
# Backward of erf-based GELU. With u = x / sqrt(2):
#   d/dx gelu(x) = x * exp(-u^2) / sqrt(2*pi) + 0.5 * erf(u) + 0.5
# and the returned input gradient is that derivative times dy.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)  # evaluate in float32
    u = inv_sqrt2 * xf
    density_term = inv_sqrt_2pi * xf * exp(-u * u)
    local_grad = density_term + 0.5 * fast_erf(u) + 0.5
    return local_grad * dy
# Backward of tanh-approximated GELU. With
#   z = c1 * (x + c2 * x^3),  t = tanh(z),  dz/dx = c1 + 3 * c1 * c2 * x^2
# the derivative is
#   d/dx gelu(x) = 0.5 * x * (1 - t^2) * dz/dx + 0.5 + 0.5 * t
# and the returned input gradient is that derivative times dy.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    x_fp32 = x.to(tl.float32)  # evaluate in float32
    c1 = 0.79788456  # math.sqrt(2 / math.pi)
    c2 = 0.044715
    # z = c1 * (x + c2 * x**3)
    tanh_out = fast_tanh(c1 * x_fp32 + c1 * x_fp32 * c2 * x_fp32 * x_fp32)
    # dz_dx = c1 * (1 + 3 * c2 * x * x)
    # 0.1070322243 = c1 * 3 * c2
    # Use the float32 copy of x in every term (the un-cast input was previously
    # mixed in here) so the whole derivative is computed in one precision,
    # consistent with the other GELU kernels in this module.
    dydx = (
        0.5
        * (
            (x_fp32 - x_fp32 * tanh_out * tanh_out)
            * (c1 + 0.1070322243 * x_fp32 * x_fp32)
        )
        + 0.5
        + 0.5 * tanh_out
    )
    dx = dydx * dy
    return dx
def gelu(self, *, approximate="none"):
    """Apply GELU to ``self`` and return the kernel's result.

    ``approximate == "tanh"`` dispatches to ``gelu_tanh``; any other value
    dispatches to ``gelu_none`` (the erf formulation). The kernel is called
    with its ``inplace`` argument set to ``False``.
    """
    logger.debug("GEMS_CAMBRICON GELU FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(self, False)
def gelu_backward(grad_output, self, *, approximate="none"):
    """Return the GELU input gradient for input ``self`` and ``grad_output``.

    ``approximate == "tanh"`` dispatches to ``gelu_backward_tanh``; any other
    value dispatches to ``gelu_backward_none``. Note the argument order flips:
    the kernels take the forward input first and the upstream gradient second.
    """
    logger.debug("GEMS_CAMBRICON GELU BACKWARD")
    kernel = gelu_backward_tanh if approximate == "tanh" else gelu_backward_none
    return kernel(self, grad_output)
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A``, passing ``A`` itself as the kernel's ``out0``.

    ``approximate == "tanh"`` dispatches to ``gelu_tanh``; any other value
    dispatches to ``gelu_none``. The kernel is called with its ``inplace``
    argument set to ``True`` and ``out0=A``, and its result is returned.
    NOTE(review): presumably ``out0=A`` makes the kernel write into ``A``;
    confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS_CAMBRICON GELU_ FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(A, True, out0=A)