Coverage for src/flag_gems/runtime/backend/_arm/ops/gelu.py: 0%
60 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9logger = logging.getLogger(__name__)
10erf = tl_extra_shim.erf
11exp = tl_extra_shim.exp
12pow = tl_extra_shim.pow
13tanh = tl_extra_shim.tanh
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    rsqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    return 0.5 * x * (1 + erf(x * rsqrt2))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    """Tanh-approximated GELU:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """
    # 0.79788456 = math.sqrt(2 / math.pi)
    inner = x * 0.79788456 * (1 + 0.044715 * pow(x.to(tl.float32), 2))
    return 0.5 * x * (1 + tanh(inner))
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    """Backward of the exact (erf) GELU: dx = dy * d/dx[0.5*x*(1+erf(x/sqrt(2)))]."""
    rsqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    rsqrt2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    # derivative = x * pdf(x) + cdf(x), written in the same evaluation order
    # as the forward formulation to keep float results bit-stable
    grad = (
        rsqrt2pi * xf * exp(-pow(rsqrt2 * xf, 2))
        + 0.5 * erf(rsqrt2 * xf)
        + 0.5
    )
    return grad * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    """Backward of the tanh-approximated GELU."""
    xf = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    t = tanh(0.79788456 * x * (1 + 0.044715 * pow(xf, 2)))
    # 0.1070322243 = 3 * 0.044715 * math.sqrt(2 / math.pi)
    sech2_term = (1 - pow(t, 2)) * (0.79788456 + 0.1070322243 * pow(xf, 2))
    grad = 0.5 * x * sech2_term + 0.5 * (1 + t)
    return grad * dy
class Gelu(torch.autograd.Function):
    """Autograd wrapper dispatching to the exact or tanh GELU kernels."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS GELU FORWARD")
        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        out = kernel(A)
        # Input and the chosen approximation mode are needed for backward.
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS GELU BACKWARD")
        (inp,) = ctx.saved_tensors
        bwd = gelu_backward_tanh if ctx.approximate == "tanh" else gelu_backward_none
        # Second return matches the non-differentiable `approximate` argument.
        return bwd(inp, out_grad), None
def gelu(A, *, approximate="none"):
    """Apply GELU to tensor ``A``.

    Args:
        A: input tensor.
        approximate: "tanh" selects the tanh approximation; any other
            value (default "none") uses the exact erf formulation.

    Returns:
        Tensor with GELU applied elementwise, with autograd support.
    """
    # Fixed: replaced a leftover debug print() to stdout with the module logger.
    logger.debug("GEMS GELU")
    return Gelu.apply(A, approximate)