Coverage for src/flag_gems/fused/geglu.py: 50%
68 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import tl_extra_shim
10erf = tl_extra_shim.erf
11exp = tl_extra_shim.exp
12pow = tl_extra_shim.pow
13tanh = tl_extra_shim.tanh
15logger = logging.getLogger(__name__)
@triton.jit
def geglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    # GEGLU forward kernel.
    #
    # The input is viewed as an (M, 2*H) matrix: columns [0, H) are the
    # "gate" half `x_a` and columns [H, 2*H) are the "value" half `x_b`.
    # Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_H) tile of
    #   output[m, h] = gelu_tanh(x_a[m, h]) * x_b[m, h]
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Mask guards both halves' loads and the store against the ragged
    # edge tiles (offs_h indexes within [0, H), so offs_h + H stays in
    # bounds of the 2*H-wide input whenever the mask is true).
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # Pointers into the gate half (columns [0, H)) ...
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    # ... and the value half (columns [H, 2*H)) of the same rows.
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    output_ptr = (
        output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
    )

    # Compute in float32 regardless of the storage dtype for accuracy.
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    # Tanh approximation of GELU:
    #   gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # 0.79788456 ≈ sqrt(2/pi); the cubic is factored as x * (1 + 0.044715 x^2).
    gelu_out = 0.5 * x_a * (1 + tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2))))
    out = gelu_out * x_b

    # tl.store converts to the output pointer's element type; the explicit
    # .to(tl.float32) keeps the pre-store value in float32.
    tl.store(output_ptr, out.to(tl.float32), mask=mask)
@triton.jit
def dgeglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    # GEGLU backward kernel.
    #
    # Forward was out = gelu(x_a) * x_b with x_a / x_b the first / second
    # halves of the (M, 2*H) input.  This kernel writes the (M, 2*H)
    # input gradient:
    #   d/dx_a = grad_out * x_b * gelu'(x_a)   -> columns [0, H)
    #   d/dx_b = grad_out * gelu(x_a)          -> columns [H, 2*H)
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    grad_out_ptr = (
        grad_out_ptr
        + offs_m[:, None] * stride_grad_out_m
        + offs_h[None, :] * stride_grad_out_h
    )
    # Gate half of the saved input (columns [0, H)).
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    # Value half of the saved input (columns [H, 2*H)).
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    # Gradient destinations mirror the same two-half layout.
    grad_a_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_h[None, :] * stride_grad_in_h
    )
    grad_b_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_h[None, :] + H) * stride_grad_in_h
    )

    # Compute in float32 regardless of storage dtype for accuracy.
    grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32)
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    # Recompute the forward tanh-GELU terms (cheaper than saving them).
    # u = sqrt(2/pi) * (x + 0.044715 x^3), factored as x * (1 + 0.044715 x^2).
    tanh_out = tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2)))
    gelu_out = 0.5 * x_a * (1 + tanh_out)

    # dgelu/dx = 0.5*(1 + tanh(u)) + 0.5*x * sech^2(u) * du/dx,
    # with sech^2(u) = 1 - tanh^2(u) and
    # du/dx = sqrt(2/pi) * (1 + 3*0.044715 x^2).
    sech2 = 1 - pow(tanh_out, 2)
    dgelu = 0.5 * (1 + tanh_out) + 0.5 * x_a * sech2 * 0.79788456 * (
        1 + 3 * 0.044715 * pow(x_a, 2)
    )

    grad_a = grad_out * x_b * dgelu
    grad_b = grad_out * gelu_out

    # x_a is float32 here, so this stores float32; tl.store then converts
    # to the gradient buffer's element type.
    tl.store(grad_a_ptr, grad_a.to(x_a.dtype), mask=mask)
    tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask)
def geglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """GEGLU activation: ``gelu(a) * b`` over halves of the last dimension.

    The last dimension of ``input_tensor`` (size ``2*H``) is split into a
    gate half ``a`` (first ``H`` entries) and a value half ``b`` (last ``H``
    entries); the result is ``gelu_tanh(a) * b``.

    Args:
        input_tensor: tensor whose last dimension has even size ``2*H``.
        quantizer: unused; accepted for interface compatibility.

    Returns:
        Tensor of shape ``(*input_tensor.shape[:-1], H)`` with the same
        dtype and device as ``input_tensor``.

    Raises:
        ValueError: if the last dimension is odd.
    """
    logger.debug("GEMS GEGLU")
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # An odd last dim cannot be split into two equal halves; fail with
        # a clear message instead of an opaque view/reshape error below.
        raise ValueError(
            f"geglu expects an even last dimension, got {shape[-1]}"
        )
    H = shape[-1] // 2
    if H == 0:
        # Empty last dimension: nothing to compute (also avoids dividing
        # by zero when deriving M below).
        return torch.empty(
            *shape[:-1], 0, device=input_tensor.device, dtype=input_tensor.dtype
        )
    # Flatten all leading dims into M rows of width 2*H.
    M = input_tensor.numel() // (2 * H)

    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    geglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return output_2d.view(*shape[:-1], H)
def dgeglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of :func:`geglu`.

    Given the upstream gradient of ``gelu(a) * b`` (shape ``(..., H)``) and
    the saved forward input (shape ``(..., 2*H)``), computes the gradient
    with respect to the forward input.

    Args:
        grad_output: upstream gradient, last dimension ``H``.
        input_tensor: forward input, last dimension ``2*H``.
        quantizer: unused; accepted for interface compatibility.

    Returns:
        Tensor with the same shape, dtype and device as ``input_tensor``.

    Raises:
        ValueError: if ``input_tensor``'s last dimension is odd.
    """
    logger.debug("GEMS DGEGLU")
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # Mirror the forward-pass contract: halves must be equal-sized.
        raise ValueError(
            f"dgeglu expects an even last dimension, got {shape[-1]}"
        )
    H = shape[-1] // 2
    if H == 0:
        # Empty last dimension: the gradient is an equally empty tensor
        # (also avoids dividing by zero when deriving M below).
        return torch.empty_like(input_tensor)
    # Flatten all leading dims into M rows.
    M = input_tensor.numel() // (2 * H)

    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dgeglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return grad_in_2d.view_as(input_tensor)