Coverage report for src/flag_gems/runtime/backend/_hygon/ops/gelu.py: 0% of 84 statements covered.
Generated by coverage.py v7.6.9 at 2026-03-13 10:08 +0800.
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Module-level logger, named after this module per project convention.
logger = logging.getLogger(__name__)

# Math primitives are resolved through the runtime shim layer so the
# backend (here: Hygon) can substitute its own implementations.
erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow  # NOTE: shadows the `pow` builtin inside this module
tanh = tl_extra_shim.tanh
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in fp32."""
    xf = x.to(tl.float32)
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    return 0.5 * xf * (1 + erf(xf * inv_sqrt2))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    """Tanh-approximated GELU, evaluated in fp32.

    gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    # Fix: dropped a redundant `.to(tl.float32)` on x_fp32 inside pow();
    # x_fp32 is already fp32 (consistent with gelu_backward_tanh).
    output = (
        0.5
        * x_fp32
        * (1 + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2))))
    )
    return output
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    """Gradient of the exact (erf-based) GELU.

    d/dx gelu(x) = x/sqrt(2*pi) * exp(-x^2/2) + 0.5*erf(x/sqrt(2)) + 0.5
    """
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    local_grad = (
        inv_sqrt2pi * xf * exp(-pow(inv_sqrt2 * xf, 2))
        + 0.5 * erf(inv_sqrt2 * xf)
        + 0.5
    )
    return local_grad * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    """Gradient of the tanh-approximated GELU."""
    xf = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    t = tanh(0.79788456 * xf * (1 + 0.044715 * pow(xf, 2)))
    # 0.1070322243 = 3 * 0.044715 * sqrt(2 / math.pi)  (derivative of the cubic term)
    local_grad = 0.5 * xf * (
        (1 - pow(t, 2)) * (0.79788456 + 0.1070322243 * pow(xf, 2))
    ) + 0.5 * (1 + t)
    return local_grad * dy
class Gelu(torch.autograd.Function):
    """Autograd wrapper dispatching to the erf- or tanh-based GELU kernels."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS GELU FORWARD")
        # Select the forward kernel by the requested approximation mode.
        out = gelu_tanh(A) if approximate == "tanh" else gelu_none(A)
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS GELU BACKWARD")
        (inp,) = ctx.saved_tensors
        # Second return slot is the (non-tensor) `approximate` argument.
        if ctx.approximate == "tanh":
            return gelu_backward_tanh(inp, out_grad), None
        return gelu_backward_none(inp, out_grad), None
def gelu(A, *, approximate="none"):
    """Out-of-place GELU; `approximate` selects "none" (erf) or "tanh"."""
    result = Gelu.apply(A, approximate)
    return result
class InplaceGelu(torch.autograd.Function):
    """In-place GELU: forward overwrites the input tensor A."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS GELU_ FORWARD")
        # Save a copy *before* the in-place write so backward still sees
        # the original input values.
        ctx.save_for_backward(A.clone())
        ctx.mark_dirty(A)
        ctx.approximate = approximate

        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        # out0=A writes the result back into the input buffer.
        return kernel(A, out0=A)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS GELU_ BACKWARD")
        (inp,) = ctx.saved_tensors
        # Second return slot is the (non-tensor) `approximate` argument.
        if ctx.approximate == "tanh":
            return gelu_backward_tanh(inp, out_grad), None
        return gelu_backward_none(inp, out_grad), None
def gelu_(A, *, approximate="none"):
    """In-place GELU: mutates A and returns it for chaining."""
    InplaceGelu.apply(A, approximate)
    return A