# Coverage for src/flag_gems/fused/FLA/fused_cumsum_kkt_solve_tril.py: 10%
# 157 statements — coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
# Copyright (c) 2025 FlagGems. All rights reserved.
# Fused cumsum + KKT + solve_tril for chunk_gated_delta_rule. Returns g_out, A_inv; w_u is separate.
# License: Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)
5from __future__ import annotations
7import torch
8import triton
9import triton.language as tl
11from flag_gems.fused.FLA.index import prepare_chunk_indices
12from flag_gems.fused.FLA.solve_tril import FLA_TRIL_PRECISION
13from flag_gems.fused.FLA.triton_ops_helper import exp, make_tensor_descriptor
14from flag_gems.fused.FLA.utils import is_tma_supported
15from flag_gems.utils import libentry, libtuner
@libentry()
@triton.heuristics(
    {
        # Ragged (variable-length) batches are signalled by a cu_seqlens tensor.
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        # Gating is unconditionally enabled in this fused variant.
        "USE_G": lambda args: True,
    }
)
@libtuner(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril_kernel(
    g_in,  # gate values, one scalar per (time step, head); read
    g_out,  # chunk-local inclusive cumsum of g_in; written
    k,  # keys, laid out with Hg key heads of width K per time step; read
    beta,  # per-(time step, head) scaling; read
    A,  # scratch: strictly-lower L per (chunk, head), BT x BT; written then re-read below
    A_inv,  # output: blockwise inverse of (I + L); written
    cu_seqlens,  # cumulative sequence lengths (varlen path only)
    chunk_indices,  # flat chunk id -> (sequence id, chunk id in sequence) pairs (varlen only)
    T,  # sequence length; overwritten per-sequence on the varlen path
    H: tl.constexpr,  # number of heads for g / beta / A
    Hg: tl.constexpr,  # number of key heads (H // Hg query heads share one key head)
    K: tl.constexpr,  # key head dimension
    BT: tl.constexpr,  # chunk (time-block) size; the solve phase hardcodes 16/32/48
    # tile offsets, so it assumes BT == 64 — TODO confirm at call sites
    BK: tl.constexpr,  # key-dimension tile size (autotuned)
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
    USE_TMA: tl.constexpr,  # tensor-descriptor (TMA) path vs. block-pointer path for A / A_inv
    DOT_PRECISION: tl.constexpr,  # input_precision for tl.dot in the triangular solve
):
    # One program per (chunk, batch*head): axis 0 = chunk id, axis 1 = b * H + h.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence index, chunk index within that
        # sequence), then derive this sequence's [bos, eos) range and true length.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    # ---------- cumsum ----------
    # Inclusive cumulative sum of the gates over time, within this chunk only.
    p_g_in = tl.make_block_ptr(
        g_in + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g_out = tl.make_block_ptr(
        g_out + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_g = tl.load(p_g_in, boundary_check=(0,)).to(tl.float32)
    b_g = tl.cumsum(b_g, axis=0)
    tl.store(p_g_out, b_g.to(p_g_out.dtype.element_ty), boundary_check=(0,))
    # ---------- KKT (write L to A) ----------
    # L = strict_tril( (k * beta) @ k^T * exp(g_i - g_j) ) for this chunk/head,
    # accumulated over BK-wide tiles of the key dimension.
    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T  # masks off rows/columns past the end of the (sub)sequence
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        # i_h // (H // Hg): map this query head onto its shared key head.
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
    if USE_G:
        # Gated decay between row position i and column position j: exp(g_i - g_j).
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * exp(b_g_diff)
    # Keep only the strictly lower triangle; zero the rest (incl. padded rows/cols).
    m_A_kkt = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A_kkt, b_A, 0)
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
    # ---------- solve_tril (read A, write A_inv) ----------
    # Invert (I + L) blockwise: view the BT x BT chunk (assumed 64 x 64) as a
    # 4 x 4 grid of 16 x 16 tiles, invert the four diagonal tiles by row-wise
    # forward substitution, then fill the strictly-lower tiles from products of
    # the diagonal inverses and the off-diagonal A tiles.
    # NOTE(review): the loads below read back the A tile this same program just
    # stored above — relies on program-order visibility of own stores.
    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]   # strictly-lower mask within a 16x16 tile
    m_I = o_i[:, None] == o_i[None, :]  # 16x16 identity pattern
    A_base = A + (bos * H + i_h) * BT
    A_inv_base = A_inv + (bos * H + i_h) * BT
    if not USE_TMA:
        # Block-pointer path: load the four diagonal 16x16 tiles of A.
        p_A_11 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_A_33 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_A_44 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
    else:
        # TMA path: same four diagonal tiles via a tensor descriptor.
        desc = make_tensor_descriptor(A_base, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
    # Seed each diagonal-tile inverse with -strict_tril(A_ii).
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
    # Forward substitution per diagonal tile: row i of the inverse is
    # -A[i, :] + (-A[i, :]) @ inv_so_far. Relative rows 0 and 1 are already
    # correct from the seed, so each loop starts at relative row 2; min(...)
    # guards a short tail chunk.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    for i in range(32 + 2, min(48, T - i_t * BT)):
        b_a_33 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 32)
        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
    for i in range(48 + 2, min(64, T - i_t * BT)):
        b_a_44 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 48)
        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
    # Add the unit diagonal: each tile now holds inv(I + strict_tril(A_ii)).
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    b_Ai_33 += m_I
    b_Ai_44 += m_I
    # Load the six strictly-lower off-diagonal 16x16 tiles of A.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_A_31 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_A_32 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_A_41 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_A_42 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_A_43 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
    # Off-diagonal tiles of the inverse, nearest-to-diagonal first,
    # e.g. inv_21 = -inv_22 @ A_21 @ inv_11; tiles further below the diagonal
    # accumulate products involving the already-computed nearer tiles.
    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )
    b_Ai_32 = -tl.dot(
        tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
        b_Ai_22,
        input_precision=DOT_PRECISION,
    )
    b_Ai_43 = -tl.dot(
        tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
        b_Ai_33,
        input_precision=DOT_PRECISION,
    )
    b_Ai_31 = -tl.dot(
        b_Ai_33,
        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_42 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_41 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    # Store the ten lower-triangular tiles of the inverse; strictly-upper tiles
    # of A_inv are never written here. "rtne" = round-to-nearest-even downcast.
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_Ai_33 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_Ai_44 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_31 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_Ai_32 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_Ai_41 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_Ai_42 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_Ai_43 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_33,
            b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_44,
            b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_31,
            b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_32,
            b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_41,
            b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_42,
            b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_43,
            b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o = make_tensor_descriptor(A_inv_base, [T, BT], [H * BT, 1], [16, 16])
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
def chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
    g: torch.Tensor,
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused cumsum + KKT + solve_tril kernel and return (g_out, A_inv).

    Per chunk of ``chunk_size`` time steps the kernel computes the in-chunk
    cumulative sum of the gates ``g`` and the blockwise inverse of the
    lower-triangular matrix built from ``k`` and ``beta``. The w/u
    recomputation stays in a separate kernel (e.g. recompute_w_u_fwd).

    Args:
        g: gate tensor; g_out is allocated with the same shape/dtype/device.
        k: key tensor of shape [B, T, Hg, K].
        beta: scaling tensor whose last dimension is the head count H.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.
        chunk_size: time-block size BT (the kernel's solve phase expects 64).
        output_dtype: dtype for A_inv; defaults to ``k.dtype``.

    Returns:
        (g_out, A_inv): the chunk-local cumsum of ``g`` and the per-chunk
        triangular inverse of shape [B, T, H, BT].
    """
    batch, seq_len, num_key_heads, head_dim = k.shape
    num_heads = beta.shape[-1]
    block_t = chunk_size
    if output_dtype is None:
        output_dtype = k.dtype
    is_varlen = cu_seqlens is not None
    chunk_indices = prepare_chunk_indices(cu_seqlens, block_t) if is_varlen else None
    num_chunks = len(chunk_indices) if is_varlen else triton.cdiv(seq_len, block_t)
    g_out = torch.empty_like(g)
    # A is scratch (fully overwritten by the kernel); A_inv must start zeroed
    # because only its lower-triangular tiles are ever written.
    A = torch.empty(
        batch, seq_len, num_heads, block_t, device=g.device, dtype=torch.float32
    )
    A_inv = torch.zeros(
        batch, seq_len, num_heads, block_t, device=g.device, dtype=output_dtype
    )
    # Grid is independent of the autotuner's meta-parameters: one program per
    # (chunk, batch * head) pair.
    grid = lambda meta: (num_chunks, batch * num_heads)  # noqa: E731
    chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril_kernel[grid](
        g_in=g,
        g_out=g_out,
        k=k,
        beta=beta,
        A=A,
        A_inv=A_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=seq_len,
        H=num_heads,
        Hg=num_key_heads,
        K=head_dim,
        BT=block_t,
        IS_VARLEN=is_varlen,
        # NOTE(review): is_tma_supported is passed uncalled, as in the original;
        # if it is a function (not a bool) this is always truthy — verify in
        # flag_gems.fused.FLA.utils.
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return g_out, A_inv