Coverage for src/flag_gems/runtime/backend/_mthreads/ops/gelu.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1# This custom op requires musa device capability >= 31.
2# We determine whether to enable this op by distinguishing the op registration for different archs.
4import logging
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import pointwise_dynamic, tl_extra_shim
12logger = logging.getLogger(
13 f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
14)
15erf = tl_extra_shim.erf
16exp = tl_extra_shim.exp
17pow = tl_extra_shim.pow
18fast_tanh = tl_extra_shim.fast_tanh
19fast_gelu = tl_extra_shim.fast_gelu
22@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
23@triton.jit
24def gelu_none(x):
25 return fast_gelu(x.to(tl.float32))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    """Tanh-approximated GELU forward.

    gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

    Fix: compute entirely in float32 (previously only the `pow` argument was
    promoted while the outer products used the raw, possibly half-precision,
    input), matching the precision handling of gelu_none and
    gelu_backward_none. The output dtype is restored by pointwise_dynamic.
    """
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    output = (
        0.5
        * x_fp32
        * (1 + fast_tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2))))
    )
    return output
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    """Backward of the exact GELU: returns dy * d(gelu)/dx.

    With Phi the standard normal CDF and phi its density:
        d/dx [x * Phi(x)] = x * phi(x) + Phi(x)
    All arithmetic is carried out in float32.
    """
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    local_grad = (
        inv_sqrt_2pi * xf * exp(-pow(inv_sqrt2 * xf, 2))
        + 0.5 * erf(inv_sqrt2 * xf)
        + 0.5
    )
    return local_grad * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    """Backward of the tanh-approximated GELU: returns dy * d(gelu)/dx.

    With u = sqrt(2/pi) * (x + 0.044715 * x^3):
        d/dx [0.5*x*(1+tanh(u))] = 0.5*(1+tanh(u)) + 0.5*x*(1-tanh(u)^2)*du/dx

    Fix: use the float32 copy of the input everywhere (previously the raw,
    possibly half-precision, `x` was used in the tanh argument and the outer
    0.5 * x product despite x_fp32 being computed), matching the precision
    handling of gelu_backward_none.
    """
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    tanh_out = fast_tanh(0.79788456 * x_fp32 * (1 + 0.044715 * pow(x_fp32, 2)))
    # 0.1070322243 = 3 * 0.044715 * sqrt(2 / pi), the derivative of the cubic term.
    dydx = 0.5 * x_fp32 * (
        (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * pow(x_fp32, 2))
    ) + 0.5 * (1 + tanh_out)
    dx = dydx * dy
    return dx
class Gelu(torch.autograd.Function):
    """Autograd bridge dispatching between the exact and tanh GELU kernels."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS_MTHREADS GELU FORWARD")
        out = gelu_tanh(A) if approximate == "tanh" else gelu_none(A)
        # Stash the input and the chosen variant for the backward pass.
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_MTHREADS GELU BACKWARD")
        (inp,) = ctx.saved_tensors
        if ctx.approximate == "tanh":
            in_grad = gelu_backward_tanh(inp, out_grad)
        else:
            in_grad = gelu_backward_none(inp, out_grad)
        # Second None: the `approximate` argument is non-differentiable.
        return in_grad, None
def gelu(A, *, approximate="none"):
    """Apply GELU to tensor A.

    `approximate` selects the variant: "tanh" for the tanh approximation,
    anything else for the exact erf-based form.
    """
    return Gelu.apply(A, approximate)