Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gelu.py: 0%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import tl_extra_shim
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Short aliases for the math routines exposed by ``tl_extra_shim``; the
# Triton kernels in this module call them by these names.
# NOTE: ``pow`` shadows the Python builtin at module scope.
erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
tanh = tl_extra_shim.tanh

# Child of the "flag_gems" logger, named after this module with any leading
# dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Elementwise exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    return 0.5 * x * (1 + erf(x * inv_sqrt2))
# Elementwise tanh-approximated GELU:
#   0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x^2)))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # The square is taken on an explicit float32 copy of the input.
    x_squared = pow(x.to(tl.float32), 2.0)
    # 0.79788456 = math.sqrt(2 / math.pi)
    tanh_arg = x * 0.79788456 * (1 + 0.044715 * x_squared)
    return 0.5 * x * (1 + tanh(tanh_arg))
# Elementwise gradient of the exact GELU, scaled by the incoming gradient dy.
# d/dx gelu(x) = x * exp(-x^2 / 2) / sqrt(2 * pi) + 0.5 * erf(x / sqrt(2)) + 0.5
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    x_f32 = x.to(tl.float32)
    # NOTE(review): this calls tl.exp directly although the module also
    # aliases ``exp`` from tl_extra_shim -- confirm whether that is deliberate.
    gauss_term = inv_sqrt_2pi * x_f32 * tl.exp(-pow(inv_sqrt2 * x_f32, 2.0))
    erf_term = 0.5 * erf(inv_sqrt2 * x_f32)
    # Summed left to right, as in the original expression.
    return (gauss_term + erf_term + 0.5) * dy
# Elementwise gradient of the tanh-approximated GELU, scaled by dy.
# With t = tanh(c * x * (1 + 0.044715 * x^2)) and c = sqrt(2 / pi):
#   d/dx = 0.5 * x * (1 - t^2) * (c + 3 * 0.044715 * c * x^2) + 0.5 * (1 + t)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    x_f32 = x.to(tl.float32)
    x_squared = pow(x_f32, 2.0)
    # 0.79788456 = math.sqrt(2 / math.pi)
    # NOTE(review): the raw ``x`` (not the float32 copy) is used in this
    # product and in ``0.5 * x`` below -- confirm this is intended.
    t = tanh(0.79788456 * x * (1 + 0.044715 * x_squared))
    one_minus_t_sq = 1 - pow(t, 2.0)
    # 0.1070322243 = 3 * 0.044715 * 0.79788456
    inner_slope = 0.79788456 + 0.1070322243 * x_squared
    dydx = 0.5 * x * (one_minus_t_sq * inner_slope) + 0.5 * (1 + t)
    return dydx * dy
def gelu(self, *, approximate="none"):
    """Apply GELU to ``self`` and return the kernel's result.

    Args:
        self: Input passed straight to the selected kernel.
        approximate: ``"tanh"`` selects ``gelu_tanh``; any other value
            selects ``gelu_none``.

    Returns:
        Whatever the selected kernel returns for ``self``.
    """
    logger.debug("GEMS GELU FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(self)
def gelu_backward(grad_output, self, *, approximate="none"):
    """Compute the GELU input gradient from ``grad_output`` and ``self``.

    Args:
        grad_output: Incoming gradient; forwarded as the kernel's ``dy``.
        self: Forward-pass input; forwarded as the kernel's ``x``.
        approximate: ``"tanh"`` selects ``gelu_backward_tanh``; any other
            value selects ``gelu_backward_none``.

    Returns:
        Whatever the selected backward kernel returns.
    """
    logger.debug("GEMS GELU BACKWARD")
    # The kernels take (x, dy), the reverse of this function's argument order.
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    return gelu_backward_none(self, grad_output)
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A``, passing ``A`` itself as the kernel's ``out0``.

    Args:
        A: Input passed to the selected kernel as both the operand and the
            ``out0`` keyword (presumably the output buffer, making the
            operation in-place -- confirm against ``pointwise_dynamic``).
        approximate: ``"tanh"`` selects ``gelu_tanh``; any other value
            selects ``gelu_none``.

    Returns:
        Whatever the selected kernel returns.
    """
    logger.debug("GEMS GELU_ FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(A, out0=A)