Coverage for src/flag_gems/fused/swiglu.py: 53%
74 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import tl_extra_shim
# Math primitives used by the kernels below. `sigmoid` comes straight from
# triton.language; `exp`/`pow` are resolved through the vendor shim.
# NOTE(review): `exp` and `pow` are not referenced in this file — presumably
# kept for parity with sibling fused ops or re-exported; confirm before removing.
sigmoid = tl.sigmoid
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow  # NOTE(review): shadows the builtin `pow` at module scope

# Module-level logger, following the flag_gems per-module logger convention.
logger = logging.getLogger(__name__)
@triton.jit
def swiglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """SwiGLU forward kernel.

    The input is viewed as an (M, 2*H) matrix whose first H columns are the
    gate half ``a`` and whose last H columns are the linear half ``b``.
    Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_H tile of

        out = silu(a) * b,   silu(x) = x * sigmoid(x)

    and stores it into the (M, H) output.
    """
    # 2-D launch grid: axis 0 tiles rows (M), axis 1 tiles columns (H).
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    # Per-tile row/column offsets.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Mask off out-of-bounds rows/columns for the ragged edge tiles.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # `a` lives in columns [0, H), `b` in columns [H, 2H) of the same row.
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    output_ptr = (
        output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
    )

    # Compute in float32 for numerical stability regardless of storage dtype.
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    silu_x_a = x_a * sigmoid(x_a)
    out = silu_x_a * x_b

    # NOTE(review): x_a was promoted to float32 above, so `.to(x_a.dtype)` is a
    # no-op; the store itself casts to the output pointer's element type.
    tl.store(output_ptr, out.to(x_a.dtype), mask=mask)
@triton.jit
def dswiglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """SwiGLU backward kernel.

    With the forward out = silu(a) * b (input viewed as (M, 2*H), ``a`` in
    columns [0, H), ``b`` in columns [H, 2H)), each program instance computes
    one tile of the (M, 2*H) input gradient:

        grad_a = grad_out * b * silu'(a),  silu'(x) = sig(x) + x*sig(x)*(1-sig(x))
        grad_b = grad_out * silu(a)
    """
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Guard ragged edge tiles along both axes.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    grad_out_ptr = (
        grad_out_ptr
        + offs_m[:, None] * stride_grad_out_m
        + offs_h[None, :] * stride_grad_out_h
    )
    # Same column split as the forward kernel: a at [0, H), b at [H, 2H).
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    # Gradient layout mirrors the input layout: grad_a then grad_b per row.
    grad_a_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_h[None, :] * stride_grad_in_h
    )
    grad_b_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_h[None, :] + H) * stride_grad_in_h
    )

    # Compute in float32 for numerical stability regardless of storage dtype.
    grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32)
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    sig = sigmoid(x_a)
    silu = x_a * sig
    # d/dx [x * sig(x)] = sig(x) + x * sig(x) * (1 - sig(x))
    d_silu = sig + x_a * sig * (1 - sig)

    grad_a = grad_out * x_b * d_silu
    grad_b = grad_out * silu

    # NOTE(review): x_a is float32 here, so `.to(x_a.dtype)` is a no-op; the
    # store casts to the grad_in pointer's element type.
    tl.store(grad_a_ptr, grad_a.to(x_a.dtype), mask=mask)
    tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask)
def swiglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """SwiGLU activation: split the last dimension in half and return silu(a) * b.

    The input is interpreted as ``[..., 2*H] = concat(a, b)`` along the last
    dimension; the result is ``silu(a) * b`` with shape ``[..., H]``.

    Args:
        input_tensor: CUDA tensor whose last dimension is even.
        quantizer: accepted for interface compatibility; currently unused.

    Returns:
        Tensor of shape ``[..., H]`` with the input's dtype and device.

    Raises:
        ValueError: if the last dimension is odd, or the tensor is not on CUDA.
    """
    logger.debug("GEMS SWIGLU")
    if input_tensor.shape[-1] % 2 != 0:
        # Fixed message: previously read "The last dimension of must be even number".
        raise ValueError(
            f"The last dimension of input must be an even number, got {input_tensor.shape[-1]}"
        )
    if not input_tensor.is_cuda:
        raise ValueError("Only CUDA tensor is supported by SwiGLU")

    shape = input_tensor.shape
    H = shape[-1] // 2
    if H == 0:
        # Degenerate last dim of size 0: the original `numel() // (2 * H)`
        # would raise ZeroDivisionError; there is nothing to compute.
        return input_tensor.new_empty(*shape[:-1], 0)
    M = input_tensor.numel() // (2 * H)
    # Flatten all leading dims so the kernel sees a 2-D (M, 2H) problem.
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    # One program per (BLOCK_SIZE_M x BLOCK_SIZE_H) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    swiglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    # Restore the original leading dims with the halved last dim.
    return output_2d.view(*shape[:-1], H)
def dswiglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward of :func:`swiglu`.

    Given the upstream gradient w.r.t. the forward output (shape ``[..., H]``)
    and the forward input (shape ``[..., 2*H]``), returns the gradient w.r.t.
    the input, laid out the same way as the input (grad_a then grad_b along
    the last dimension).

    Args:
        grad_output: upstream gradient, shape ``[..., H]``.
        input_tensor: forward input, shape ``[..., 2*H]``.
        quantizer: accepted for interface compatibility; currently unused.

    Returns:
        Tensor with the same shape, dtype, and device as ``input_tensor``.

    Raises:
        ValueError: if the last dimension of ``input_tensor`` is odd.
    """
    logger.debug("GEMS DSWIGLU")
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and this matches swiglu's ValueError-based validation.
        raise ValueError(
            f"The last dimension of input_tensor must be an even number, got {shape[-1]}"
        )
    H = shape[-1] // 2
    if H == 0:
        # Degenerate last dim of size 0: avoids ZeroDivisionError in the
        # `numel() // (2 * H)` below; the gradient is an empty tensor.
        return torch.empty_like(input_tensor)
    M = input_tensor.numel() // (2 * H)
    # Flatten leading dims: grads/outputs are (M, H), inputs are (M, 2H).
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    # One program per (BLOCK_SIZE_M x BLOCK_SIZE_H) tile of the H columns.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dswiglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    return grad_in_2d.view_as(input_tensor)
194__all__ = ["swiglu", "dswiglu"]