Coverage for src/flag_gems/fused/FLA/cumsum.py: 23%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1# This file contains code copied from the flash-linear-attention project.
2# The original source code was licensed under the MIT license and included
3# the following copyright notice:
4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
5# ruff: noqa: E501
6import warnings
8import torch
9import triton
10import triton.language as tl
12from flag_gems.fused.FLA.index import prepare_chunk_indices
13from flag_gems.fused.FLA.utils import check_shared_mem, input_guard
14from flag_gems.utils import libentry, libtuner
# Candidate feature-tile sizes (BS) for the vector kernel's autotuner;
# presumably the larger tiles are only worth trying when the device has
# enough shared memory (that is what check_shared_mem appears to gate) —
# TODO(review): confirm against flag_gems.fused.FLA.utils.check_shared_mem.
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,  # input: one scalar per (batch, time, head) position
    o,  # output: same layout and extent as `s`
    cu_seqlens,  # cumulative sequence lengths (varlen mode), else None
    chunk_indices,  # varlen only: flat chunk id -> (seq idx, chunk-in-seq) pairs
    T,  # time extent; overwritten per-sequence in varlen mode
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,  # chunk (tile) size along the time axis
    REVERSE: tl.constexpr,  # True: suffix (reversed) cumsum within each chunk
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # True: [B, H, T] layout; False: [B, T, H]
):
    # One program per (time-chunk, batch*head) pair: cumulative sum of `s`
    # restricted to a single BT-sized chunk along time.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk-within-sequence) and
        # derive that sequence's [bos, eos) span from cu_seqlens.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        # Head-first layout: time is contiguous (stride 1) within each head.
        # NOTE(review): the base offset `bos * H + i_h * T` only matches the
        # fixed-length case (bos == i_b * T); it looks wrong if IS_VARLEN and
        # HEAD_FIRST are combined — confirm callers never do that.
        p_s = tl.make_block_ptr(
            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        p_o = tl.make_block_ptr(
            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
    else:
        # Time-major layout: consecutive time steps are H elements apart.
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT] — accumulate in fp32 regardless of the storage dtype.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Suffix sum: total - inclusive_prefix + current element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,  # input: an S-sized feature vector per (batch, time, head) position
    o,  # output: same layout and extent as `s`
    cu_seqlens,  # cumulative sequence lengths (varlen mode), else None
    chunk_indices,  # varlen only: flat chunk id -> (seq idx, chunk-in-seq) pairs
    T,  # time extent; overwritten per-sequence in varlen mode
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,  # feature dimension
    BT: tl.constexpr,  # chunk (tile) size along the time axis
    BS: tl.constexpr,  # feature tile size, chosen by the autotuner
    REVERSE: tl.constexpr,  # True: suffix (reversed) cumsum within each chunk
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # True: [B, H, T, S] layout; False: [B, T, H, S]
):
    # One program per (feature-tile, time-chunk, batch*head) triple: the
    # within-chunk cumsum is expressed as a triangular-matrix product.
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk-within-sequence) and
        # derive that sequence's [bos, eos) span from cu_seqlens.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, BT)
    if REVERSE:
        # Inclusive upper-triangular mask -> suffix sums via tl.dot.
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        # Inclusive lower-triangular mask -> prefix sums via tl.dot.
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
    if HEAD_FIRST:
        # Head-first layout: features contiguous, time stride S within a head.
        # NOTE(review): as in the scalar kernel, `bos * H + i_h * T` assumes
        # the fixed-length case; varlen + HEAD_FIRST looks unsupported.
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        # Time-major layout: consecutive time steps are H * S elements apart.
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS] — accumulate in fp32 regardless of the storage dtype.
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    # tf32 disabled so the triangular matmul is an exact fp32 summation.
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """Chunk-local cumulative sum for a 3-D (scalar-gate) tensor.

    Launches ``chunk_local_cumsum_scalar_kernel`` on a (num_chunks, B * H)
    grid and returns a freshly allocated result; ``g`` is never modified.
    ``chunk_size`` must be a power of two.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        # Variable-length mode: one grid slot per (sequence, chunk) pair.
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    chunk_local_cumsum_scalar_kernel[(NT, B * H)](
        g,
        out,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return out
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """Chunk-local cumulative sum for a 4-D (vector-gate) tensor.

    Equivalent (head-first layout, fixed length) to
    ``g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)``, computed by
    ``chunk_local_cumsum_vector_kernel``. Returns a new tensor; ``g`` is
    left untouched. ``chunk_size`` must be a power of two.
    """
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    # Validate chunk_size BEFORE it is used to build chunk_indices and NT
    # (previously this assert ran after both uses; validating first matches
    # chunk_local_cumsum_scalar).
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        # BS is picked by the autotuner, so the feature-axis grid extent is
        # only known at launch time.
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # The kernel accumulates in fp32 regardless of the storage dtype.
    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
    **kwargs,
) -> torch.Tensor:
    """Dispatch a chunk-local cumsum to the scalar (3-D) or vector (4-D)
    implementation based on the rank of ``g``.

    Warns when a non-head-first input looks like it might actually be in
    head-first layout, and requires batch size 1 whenever ``cu_seqlens``
    is supplied.
    """
    if not head_first and g.shape[1] < g.shape[2]:
        # Heuristic sanity check: seq_len shorter than num_heads usually
        # means the caller passed [B, H, T, ...] with head_first=False.
        warnings.warn(
            "Input tensor shape suggests potential format mismatch: "
            f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        assert (
            g.shape[0] == 1
        ), "Only batch size 1 is supported when cu_seqlens are provided"
    if g.ndim == 3:
        impl = chunk_local_cumsum_scalar
    elif g.ndim == 4:
        impl = chunk_local_cumsum_vector
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}. "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise"
        )
    return impl(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)