Coverage for src/flag_gems/fused/swiglu.py: 51%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import tl_extra_shim
logger = logging.getLogger(__name__)

# Elementwise math helpers. ``sigmoid`` is the only one the kernels in this
# module call; ``exp`` and ``pow`` are backend-specific implementations taken
# from ``tl_extra_shim``. NOTE: ``pow`` shadows the builtin within this module.
sigmoid = tl.sigmoid
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
@triton.jit
def swiglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Forward SwiGLU on one (BLOCK_SIZE_M, BLOCK_SIZE_H) tile.

    Each input row holds the gate half ``a`` in columns ``[0, H)`` and the
    linear half ``b`` in columns ``[H, 2H)``. Writes
    ``out[m, h] = silu(a[m, h]) * b[m, h]`` into an ``M x H`` output.
    Math is done in float32 and cast to the output element type on store.
    """
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # Promote the row index to int64 before scaling by the row stride: in
    # int32, offs_m * stride wraps once the row offset exceeds 2**31 - 1.
    rows = offs_m[:, None].to(tl.int64)
    cols = offs_h[None, :]

    in_row = input_ptr + rows * stride_in_m
    a_ptrs = in_row + cols * stride_in_h
    b_ptrs = in_row + (cols + H) * stride_in_h
    out_ptrs = output_ptr + rows * stride_out_m + cols * stride_out_h

    x_a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(b_ptrs, mask=mask, other=0.0).to(tl.float32)

    out = x_a * sigmoid(x_a) * x_b
    # x_a is already float32 here, so cast to the destination's element type
    # explicitly rather than to x_a.dtype (which would be a no-op).
    tl.store(out_ptrs, out.to(output_ptr.dtype.element_ty), mask=mask)
@triton.jit
def dswiglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Backward SwiGLU on one (BLOCK_SIZE_M, BLOCK_SIZE_H) tile.

    With ``a = x[:, :H]``, ``b = x[:, H:]`` and ``g = grad_out`` (``M x H``),
    writes into ``grad_in`` (same ``[a | b]`` column layout as the input):

        grad_in[:, :H] = g * b * silu'(a)
        grad_in[:, H:] = g * silu(a)

    where ``silu'(a) = sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a))``.
    Math is done in float32 and cast to grad_in's element type on store.
    """
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # int64 row index: in int32, offs_m * stride wraps once the row offset
    # exceeds 2**31 - 1.
    rows = offs_m[:, None].to(tl.int64)
    cols = offs_h[None, :]

    go_ptrs = grad_out_ptr + rows * stride_grad_out_m + cols * stride_grad_out_h
    in_row = input_ptr + rows * stride_in_m
    a_ptrs = in_row + cols * stride_in_h
    b_ptrs = in_row + (cols + H) * stride_in_h
    gin_row = grad_in_ptr + rows * stride_grad_in_m
    ga_ptrs = gin_row + cols * stride_grad_in_h
    gb_ptrs = gin_row + (cols + H) * stride_grad_in_h

    grad_out = tl.load(go_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(b_ptrs, mask=mask, other=0.0).to(tl.float32)

    sig = sigmoid(x_a)
    silu = x_a * sig
    d_silu = sig + x_a * sig * (1 - sig)

    grad_a = grad_out * x_b * d_silu
    grad_b = grad_out * silu

    # The loaded values were upcast to float32, so cast explicitly to the
    # destination element type instead of x_a.dtype (a no-op).
    out_ty = grad_in_ptr.dtype.element_ty
    tl.store(ga_ptrs, grad_a.to(out_ty), mask=mask)
    tl.store(gb_ptrs, grad_b.to(out_ty), mask=mask)
def swiglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply SwiGLU along the last dimension of ``input_tensor``.

    The last dimension (size ``2 * H``) is split into a gate half ``a`` (first
    ``H`` columns) and a linear half ``b`` (last ``H`` columns); the result is
    ``silu(a) * b``.

    Args:
        input_tensor: CUDA tensor whose last dimension has even size.
        quantizer: Accepted for interface compatibility; not used here.

    Returns:
        A new contiguous tensor of shape ``(*input_tensor.shape[:-1], H)`` with
        the same dtype and device as ``input_tensor``.

    Raises:
        ValueError: If the last dimension is odd or the tensor is not on CUDA.
    """
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        raise ValueError(
            "The last dimension of input_tensor must be an even number, "
            f"got {shape[-1]}"
        )
    if not input_tensor.is_cuda:
        raise ValueError("Only CUDA tensor is supported by SwiGLU")

    H = shape[-1] // 2
    out_shape = (*shape[:-1], H)
    # Nothing to compute for empty inputs; this also avoids dividing by
    # 2 * H == 0 below when the last dimension is empty.
    if input_tensor.numel() == 0:
        return input_tensor.new_empty(out_shape)

    M = input_tensor.numel() // (2 * H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    swiglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    return output_2d.view(out_shape)
def dswiglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward of :func:`swiglu` with respect to its input.

    Args:
        grad_output: Gradient of the SwiGLU output; it is reshaped to
            ``(M, H)``, so it must have ``M * H`` elements.
        input_tensor: The forward input, last dimension of size ``2 * H``.
        quantizer: Accepted for interface compatibility; not used here.

    Returns:
        A tensor shaped like ``input_tensor`` (same dtype/device) holding the
        gradient for the gate half in the first ``H`` columns and for the
        linear half in the last ``H`` columns.

    Raises:
        ValueError: If the last dimension of ``input_tensor`` is odd. (This
            was previously an ``assert``, which is stripped under ``-O``.)
    """
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        raise ValueError(
            "The last dimension of input_tensor must be an even number, "
            f"got {shape[-1]}"
        )
    # Nothing to compute for empty inputs; this also avoids dividing by
    # 2 * H == 0 below when the last dimension is empty.
    if input_tensor.numel() == 0:
        return input_tensor.new_empty(shape)

    H = shape[-1] // 2
    M = input_tensor.numel() // (2 * H)
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dswiglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    return grad_in_2d.view_as(input_tensor)
# Public entry points exported by ``from ... import *``.
__all__ = [
    "swiglu",
    "dswiglu",
]