Coverage for src/flag_gems/ops/gelu.py: 70%
56 statements
coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Device math functions resolved once from the backend shim (tl_extra_shim);
# the Triton kernels in this module refer to them through these module-level
# names. NOTE: binding `pow` here shadows the Python builtin within this module.
erf, exp, pow, tanh = (
    tl_extra_shim.erf,
    tl_extra_shim.exp,
    tl_extra_shim.pow,
    tl_extra_shim.tanh,
)

logger = logging.getLogger(__name__)
# Exact (erf-based) GELU forward, element-wise:
#   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
# The erf argument is formed from a float32 copy of x. The output dtype is
# presumably governed by pointwise_dynamic's "DEFAULT" promotion of argument 0.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    erf_term = 1 + erf(x.to(tl.float32) * inv_sqrt2)
    half_x = 0.5 * x
    return half_x * erf_term
# Tanh-approximated GELU forward, element-wise:
#   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x^2)))
# The tanh argument is evaluated on a float32 copy of x. The output dtype is
# presumably governed by pointwise_dynamic's "DEFAULT" promotion of argument 0.
# x^2 is computed as a plain multiply instead of pow(): it is exact, avoids a
# transcendental call per element, and does not rely on the shim's pow for
# negative bases.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    inner = x_fp32 * 0.79788456 * (1 + 0.044715 * (x_fp32 * x_fp32))
    output = 0.5 * x * (1 + tanh(inner))
    return output
# Backward of exact (erf-based) GELU, element-wise.
#   gelu(x) = x * Phi(x), so d gelu / dx = Phi(x) + x * phi(x), where
#     Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
#     phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
# Returns dy * d gelu / dx; the derivative itself is evaluated in float32.
# The square is a plain multiply rather than pow(): exact, cheaper, and
# independent of the shim's pow handling of negative bases.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    scale1: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    scale2: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    x_fp32 = x.to(tl.float32)
    # z = x / sqrt(2), so z * z == x^2 / 2 (the Gaussian exponent).
    z = scale1 * x_fp32
    dydx = scale2 * x_fp32 * exp(-(z * z)) + 0.5 * erf(z) + 0.5
    dx = dydx * dy
    return dx
# Backward of tanh-approximated GELU, element-wise.
#   u = sqrt(2 / pi) * x * (1 + 0.044715 * x^2),  t = tanh(u)
#   d gelu / dx = 0.5 * (1 + t)
#               + 0.5 * x * (1 - t^2) * sqrt(2 / pi) * (1 + 3 * 0.044715 * x^2)
# Returns dy * d gelu / dx; the derivative itself is evaluated in float32.
# Squares are plain multiplies (x^2 is shared by u and du/dx) instead of three
# pow() calls: exact, cheaper, and independent of the shim's pow handling of
# negative bases.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    x_fp32 = x.to(tl.float32)
    x_sq = x_fp32 * x_fp32
    # 0.79788456 = math.sqrt(2 / math.pi)
    tanh_out = tanh(0.79788456 * x_fp32 * (1 + 0.044715 * x_sq))
    # 0.1070322243 = 3 * 0.044715 * math.sqrt(2 / math.pi), i.e. du/dx is
    # 0.79788456 + 0.1070322243 * x^2.
    dydx = 0.5 * x_fp32 * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x_sq)
    ) + 0.5 * (1 + tanh_out)
    dx = dydx * dy
    return dx
def gelu(self, *, approximate="none"):
    """Apply GELU element-wise to ``self`` and return the result.

    Args:
        self: Input tensor (named ``self``, presumably to mirror the ATen
            operator schema).
        approximate: ``"none"`` selects the exact erf formulation,
            ``"tanh"`` selects the tanh approximation.

    Returns:
        The tensor produced by the selected Triton kernel.

    Raises:
        ValueError: If ``approximate`` is neither ``"none"`` nor ``"tanh"``.
    """
    logger.debug("GEMS GELU FORWARD")
    if approximate == "tanh":
        return gelu_tanh(self)
    if approximate == "none":
        return gelu_none(self)
    # Reject unknown modes instead of silently computing the exact variant;
    # PyTorch's gelu only defines "none" and "tanh".
    raise ValueError(
        f"approximate argument must be either 'none' or 'tanh', got {approximate!r}"
    )
def gelu_backward(grad_output, self, *, approximate="none"):
    """Compute the gradient of GELU with respect to its input.

    Args:
        grad_output: Upstream gradient, multiplied element-wise by the local
            GELU derivative.
        self: The forward input of GELU (named ``self``, presumably to mirror
            the ATen operator schema).
        approximate: ``"none"`` for the exact erf-based derivative, ``"tanh"``
            for the derivative of the tanh approximation.

    Returns:
        ``grad_output * d gelu(self) / d self`` as produced by the selected
        Triton kernel.

    Raises:
        ValueError: If ``approximate`` is neither ``"none"`` nor ``"tanh"``.
    """
    logger.debug("GEMS GELU BACKWARD")
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    if approximate == "none":
        return gelu_backward_none(self, grad_output)
    # Reject unknown modes instead of silently using the exact derivative;
    # PyTorch's gelu only defines "none" and "tanh".
    raise ValueError(
        f"approximate argument must be either 'none' or 'tanh', got {approximate!r}"
    )
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A`` in place.

    The kernel is invoked with ``out0=A``, so the result is written back into
    ``A``; the kernel's return value is returned to the caller.

    Args:
        A: Tensor to transform in place.
        approximate: ``"none"`` selects the exact erf formulation,
            ``"tanh"`` selects the tanh approximation.

    Returns:
        The value returned by the selected kernel when writing into ``A``.

    Raises:
        ValueError: If ``approximate`` is neither ``"none"`` nor ``"tanh"``.
    """
    logger.debug("GEMS GELU_ FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A, out0=A)
    if approximate == "none":
        return gelu_none(A, out0=A)
    # Reject unknown modes instead of silently computing the exact variant;
    # PyTorch's gelu only defines "none" and "tanh".
    raise ValueError(
        f"approximate argument must be either 'none' or 'tanh', got {approximate!r}"
    )