Coverage for src/flag_gems/fused/geglu.py: 51%
70 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import tl_extra_shim
10erf = tl_extra_shim.erf
11exp = tl_extra_shim.exp
12pow = tl_extra_shim.pow
13tanh = tl_extra_shim.tanh
15logger = logging.getLogger(__name__)
@triton.jit
def geglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """GEGLU forward tile: out[m, h] = gelu(x[m, h]) * x[m, h + H].

    The input is viewed as an (M, 2*H) matrix whose first H columns are
    the gate half and last H columns are the value half.  GELU uses the
    tanh approximation; math is done in float32.
    """
    block_m = tl.program_id(0)
    block_h = tl.program_id(1)

    rows = block_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = block_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Guard the ragged tiles at the right/bottom edges.
    in_bounds = (rows[:, None] < M) & (cols[None, :] < H)

    row_offset = rows[:, None] * stride_in_m
    a_ptrs = input_ptr + row_offset + cols[None, :] * stride_in_h
    # The "b" (value) half lives H columns to the right of the "a" half.
    b_ptrs = input_ptr + row_offset + (cols[None, :] + H) * stride_in_h
    out_ptrs = (
        output_ptr + rows[:, None] * stride_out_m + cols[None, :] * stride_out_h
    )

    a = tl.load(a_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(b_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # tanh-approximated GELU: 0.5*x*(1 + tanh(sqrt(2/pi) * x * (1 + 0.044715*x^2)))
    gate = 0.5 * a * (1 + tanh(0.79788456 * a * (1 + 0.044715 * pow(a, 2))))
    result = gate * b

    tl.store(out_ptrs, result.to(tl.float32), mask=in_bounds)
@triton.jit
def dgeglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """GEGLU backward tile.

    Given upstream gradient g = d(loss)/d(out) and the saved forward
    input x (shape (M, 2*H): gate half "a" then value half "b"), write
    an (M, 2*H) gradient where
        grad_a = g * b * gelu'(a)      (first H columns)
        grad_b = g * gelu(a)           (last H columns)
    using the tanh-approximated GELU, computed in float32.
    """
    block_m = tl.program_id(0)
    block_h = tl.program_id(1)

    rows = block_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = block_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    in_bounds = (rows[:, None] < M) & (cols[None, :] < H)

    g_ptrs = (
        grad_out_ptr
        + rows[:, None] * stride_grad_out_m
        + cols[None, :] * stride_grad_out_h
    )
    in_row_offset = rows[:, None] * stride_in_m
    a_ptrs = input_ptr + in_row_offset + cols[None, :] * stride_in_h
    b_ptrs = input_ptr + in_row_offset + (cols[None, :] + H) * stride_in_h
    grad_row_offset = rows[:, None] * stride_grad_in_m
    grad_a_ptrs = grad_in_ptr + grad_row_offset + cols[None, :] * stride_grad_in_h
    grad_b_ptrs = (
        grad_in_ptr + grad_row_offset + (cols[None, :] + H) * stride_grad_in_h
    )

    g = tl.load(g_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    a = tl.load(a_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(b_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # Forward pieces reused by the derivative.
    t = tanh(0.79788456 * a * (1 + 0.044715 * pow(a, 2)))
    gelu_a = 0.5 * a * (1 + t)

    # d/da [0.5*a*(1 + tanh(u(a)))] with u(a) = sqrt(2/pi)*(a + 0.044715*a^3);
    # sech^2(u) = 1 - tanh^2(u).
    sech_sq = 1 - pow(t, 2)
    dgelu_da = 0.5 * (1 + t) + 0.5 * a * sech_sq * 0.79788456 * (
        1 + 3 * 0.044715 * pow(a, 2)
    )

    grad_a = g * b * dgelu_da
    grad_b = g * gelu_a

    tl.store(grad_a_ptrs, grad_a.to(a.dtype), mask=in_bounds)
    tl.store(grad_b_ptrs, grad_b.to(a.dtype), mask=in_bounds)
def geglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply GEGLU over the last dimension of ``input_tensor``.

    The last dimension (size ``2*H``) is split into a gate half and a
    value half; the result is ``gelu(gate) * value`` (tanh-approximated
    GELU) with last dimension ``H``.

    Args:
        input_tensor: Tensor whose last dimension has even size ``2*H``.
        quantizer: Unused; accepted for interface compatibility.

    Returns:
        Tensor of shape ``(*input_tensor.shape[:-1], H)`` with the same
        dtype and device as the input.

    Raises:
        ValueError: If the last dimension is not even.
    """
    logger.debug("GEMS GEGLU")
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # Fail early with a clear message instead of an opaque view() error.
        raise ValueError(f"geglu requires an even last dimension, got {shape[-1]}")
    H = shape[-1] // 2
    M = input_tensor.numel() // (2 * H)

    # Flatten all leading dims so the kernel sees a simple (M, 2*H) matrix.
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    grid = lambda META: (  # noqa: E731 - matches the file's launch-grid style
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    geglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return output_2d.view(*shape[:-1], H)
def dgeglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of :func:`geglu`.

    Args:
        grad_output: Upstream gradient with shape
            ``(*input_tensor.shape[:-1], H)`` where ``2*H`` is the
            input's last dimension.
        input_tensor: The forward input (last dimension ``2*H``).
        quantizer: Unused; accepted for interface compatibility.

    Returns:
        Gradient w.r.t. ``input_tensor`` with the same shape as
        ``input_tensor``.

    Raises:
        ValueError: If the input's last dimension is not even.
    """
    logger.debug("GEMS DGEGLU")
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # Fail early with a clear message instead of an opaque view() error.
        raise ValueError(f"dgeglu requires an even last dimension, got {shape[-1]}")
    H = shape[-1] // 2
    M = input_tensor.numel() // (2 * H)

    # Flatten leading dims so the kernel works on 2-D matrices.
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    grid = lambda META: (  # noqa: E731 - matches the file's launch-grid style
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dgeglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return grad_in_2d.view_as(input_tensor)