Coverage for src/flag_gems/experimental_ops/gelu_.py: 0%
74 statements
« prev ^ index » next    coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch # noqa: F401
2import triton
3import triton.language as tl
@triton.jit
def gelu_(
    x_ptr,  # *Pointer* to the input/output tensor (modified in place).
    n_elements,  # Total number of elements in the tensor.
    USE_TANH: tl.constexpr,  # Whether to use the tanh approximation.
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program instance.
):
    """Apply GELU in place over a contiguous run of `n_elements` values.

    Each program instance processes one BLOCK_SIZE slice. Arithmetic is
    performed in float32 and the result is cast back to the tensor's
    original dtype on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)

    # Compute GELU either exact (via erf approximation) or tanh approximation
    if USE_TANH:
        # tanh approximation:
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
        c0 = 0.7978845608028654  # sqrt(2/pi)
        c1 = 0.044715
        x3 = x_f32 * x_f32 * x_f32
        z = c0 * (x_f32 + c1 * x3)
        # tanh(z) = sign(z) * (1 - e^{-2|z|}) / (1 + e^{-2|z|})
        # BUGFIX: evaluate exp with a non-positive argument. The previous
        # form exp(-2z) overflows to inf in float32 for large negative z
        # (|x| >~ 55), turning (1 - inf)/(1 + inf) into NaN; exp(-2|z|)
        # is always <= 1 and cannot overflow.
        az = tl.abs(z)
        t = tl.exp(-2.0 * az)
        tanh_abs = (1.0 - t) / (1.0 + t)
        tanh_z = tl.where(z >= 0, tanh_abs, -tanh_abs)
        y = 0.5 * x_f32 * (1.0 + tanh_z)
    else:
        # exact (erf-based) GELU:
        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        inv_sqrt2 = 0.7071067811865476
        z = x_f32 * inv_sqrt2

        # Abramowitz and Stegun formula 7.1.26 for erf approximation
        # erf(x) ≈ sign(x) * (1 - (((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t) * e^{-x^2})
        # where t = 1 / (1 + p*|x|)
        p = 0.3275911
        a1 = 0.254829592
        a2 = -0.284496736
        a3 = 1.421413741
        a4 = -1.453152027
        a5 = 1.061405429

        az = tl.abs(z)
        t = 1.0 / (1.0 + p * az)
        # Horner evaluation of the degree-5 polynomial in t.
        poly = a5
        poly = poly * t + a4
        poly = poly * t + a3
        poly = poly * t + a2
        poly = poly * t + a1
        poly = poly * t
        erf_abs = 1.0 - poly * tl.exp(-az * az)
        # Restore the sign: erf is odd.
        erf_z = tl.where(z >= 0, erf_abs, -erf_abs)

        y = 0.5 * x_f32 * (1.0 + erf_z)

    y_cast = y.to(x.dtype)
    tl.store(x_ptr + offsets, y_cast, mask=mask)
# Preserve a handle to the kernel before defining the Python wrapper of the same name.
# The `def gelu_` below shadows the kernel at module level; launches go through this alias.
gelu__kernel = gelu_
def gelu_(*args, **kwargs):
    """In-place GELU on a contiguous CUDA floating-point tensor.

    Mirrors the calling convention of ``torch.nn.functional.gelu``:
    the tensor is the first positional argument (or ``input``/``self``/``x``
    keyword), and ``approximate`` is ``"none"`` (exact, erf-based) or
    ``"tanh"``. Returns the mutated input tensor.

    Raises:
        ValueError: no tensor argument, or unsupported ``approximate`` mode.
        AssertionError: tensor is not CUDA, not contiguous, or not floating point.
    """
    # Resolve input tensor: first positional, else common keyword names.
    if len(args) >= 1:
        x = args[0]
    else:
        x = kwargs.get("input", None)
        if x is None:
            x = kwargs.get("self", None)
        if x is None:
            x = kwargs.get("x", None)
    if x is None:
        raise ValueError("gelu_ expects a tensor as the first argument.")

    # Determine approximation mode. BUGFIX: also accept it positionally
    # (second argument, matching F.gelu(input, approximate=...)); the
    # previous version silently ignored a positional `approximate`.
    if len(args) >= 2:
        approx = args[1]
    else:
        approx = kwargs.get("approximate", "none")
    if isinstance(approx, bool):
        use_tanh = bool(approx)
    else:
        approx_str = str(approx).lower()
        if approx_str in ("tanh", "true"):
            use_tanh = True
        elif approx_str in ("none", "false"):
            use_tanh = False
        else:
            raise ValueError(
                f"Unsupported approximate mode: {approx}. Use 'none' or 'tanh'."
            )

    # The kernel assumes a flat, contiguous CUDA buffer of floats.
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device for Triton kernel.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    if not x.is_floating_point():
        raise AssertionError("gelu_ expects a floating point tensor.")

    n_elements = x.numel()
    if n_elements == 0:
        return x  # nothing to do; avoid launching an empty grid

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    gelu__kernel[grid](x, n_elements, USE_TANH=use_tanh, BLOCK_SIZE=BLOCK_SIZE)
    return x