Coverage for src/flag_gems/fused/FLA/solve_tril.py: 12%
226 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
7import os
9import torch
10import triton
11import triton.language as tl
13from flag_gems.fused.FLA.index import prepare_chunk_indices
14from flag_gems.fused.FLA.triton_ops_helper import make_tensor_descriptor
15from flag_gems.fused.FLA.utils import input_guard, is_tma_supported
16from flag_gems.utils import libentry, libtuner
# Precision mode passed as `input_precision` to tl.dot inside the inversion
# kernels; configurable via the FLA_TRIL_PRECISION environment variable.
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32", "tf32x3"]
# Validate eagerly at import time. Use an explicit raise instead of `assert`
# so the check still runs under `python -O` (asserts are stripped there).
if FLA_TRIL_PRECISION not in ALLOWED_TRIL_PRECISIONS:
    raise ValueError(
        f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
    )
# Inverts (I + A) for each 16x16 strictly-lower-triangular tile of A via
# in-register forward substitution; the result is written to Ai, whose rows
# are strided by H * 16 (Ai's last dimension is 16 here).
@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,  # in:  tiles of A, row stride H * BT
    Ai,  # out: per-tile (I + A)^-1, row stride H * 16
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat (seq_idx, chunk_idx) pairs (varlen mode) or None
    T,  # sequence length (per-sequence length in varlen mode)
    H: tl.constexpr,  # number of heads
    BT: tl.constexpr,  # width of A's last dimension
    USE_TMA: tl.constexpr,  # use tensor-descriptor (TMA) loads/stores
    IS_VARLEN: tl.constexpr,  # derived heuristic: cu_seqlens is not None
    DOT_PRECISION: tl.constexpr,  # tl.dot input precision (unused in this kernel)
):
    # i_t indexes the 16-row tile along time, i_bh the fused (batch, head).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat tile index to (sequence, tile-within-sequence) and
        # load this sequence's [bos, eos) span; T becomes its length.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]  # strictly-lower-triangular mask
    m_I = o_i[:, None] == o_i[None, :]  # identity (diagonal) mask

    # Advance the base pointers to this (sequence, head).
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * 16

    # Column offset of this 16-wide tile inside a BT-wide row of A.
    offset = (i_t * 16) % BT
    if not USE_TMA:
        p_A = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
        )
        # [16, 16]
        b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
        b_A = desc.load([i_t * 16, offset]).to(tl.float32)
    # Keep only the strictly-lower part of the tile, negated (the first-order
    # term of the inverse of a unit lower-triangular matrix).
    b_A = -tl.where(m_A, b_A, 0)

    # Forward substitution: finalize row i of the inverse from the rows
    # already computed (rows 0 and 1 need no further correction, hence the
    # start at 2). The bound is clipped so a partial tile at the end of a
    # sequence does not read past row T.
    for i in range(2, min(16, T - i_t * 16)):
        # [16]
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
    # Add the identity diagonal to complete (I + A)^-1.
    b_A += m_I
    if not USE_TMA:
        p_Ai = tl.make_block_ptr(
            Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)
        )
        tl.store(
            p_Ai,
            b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
# Inverts (I + A) for each 32x32 strictly-lower-triangular block by inverting
# the two 16x16 diagonal sub-blocks independently (forward substitution, as in
# solve_tril_16x16_kernel) and then composing the off-diagonal sub-block with
# the block-triangular inverse identity:
#   Ai_21 = -Ai_22 @ A_21 @ Ai_11
@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,  # in:  32x32 blocks of A, row stride H * BT
    Ai,  # out: per-block (I + A)^-1, same layout as A
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat (seq_idx, chunk_idx) pairs (varlen mode) or None
    T,  # sequence length (per-sequence length in varlen mode)
    H: tl.constexpr,  # number of heads
    BT: tl.constexpr,  # block size (32 for this kernel)
    USE_TMA: tl.constexpr,  # use tensor-descriptor (TMA) loads/stores
    IS_VARLEN: tl.constexpr,  # derived heuristic: cu_seqlens is not None
    DOT_PRECISION: tl.constexpr,  # tl.dot input precision
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat tile index to (sequence, tile) and load the span.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]  # strictly-lower mask for a 16x16 tile
    m_I = o_i[:, None] == o_i[None, :]  # identity (diagonal) mask
    # Advance the base pointers to this (sequence, head).
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    if not USE_TMA:
        # Diagonal sub-blocks A_11 and A_22 of the 32x32 tile.
        p_A_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)

    # [16, 16]
    # Negated strictly-lower part: first-order term of each diagonal inverse.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)

    # Forward substitution on each diagonal sub-block; bounds are clipped to
    # the remaining rows of the sequence for partial tiles. The first two rows
    # of each sub-block need no correction, hence the `+ 2` starts.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)

    # Add the identity diagonal to complete each 16x16 inverse.
    b_Ai_11 += m_I
    b_Ai_22 += m_I

    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)

    # Off-diagonal sub-block of the inverse: Ai_21 = -Ai_22 @ A_21 @ Ai_11.
    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        # Store with round-to-nearest-even when downcasting from fp32.
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
# Inverts (I + A) for each 64x64 strictly-lower-triangular block. The block is
# partitioned into a 4x4 grid of 16x16 sub-blocks: the four diagonal
# sub-blocks are inverted independently by forward substitution, then the six
# off-diagonal sub-blocks are composed via block back-substitution in
# dependency order — first the first sub-diagonal (21, 32, 43), then the
# second (31, 42), then the corner (41).
@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,  # in:  64x64 blocks of A, row stride H * BT
    Ai,  # out: per-block (I + A)^-1, same layout as A
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat (seq_idx, chunk_idx) pairs (varlen mode) or None
    T,  # sequence length (per-sequence length in varlen mode)
    H: tl.constexpr,  # number of heads
    BT: tl.constexpr,  # block size (64 for this kernel)
    USE_TMA: tl.constexpr,  # use tensor-descriptor (TMA) loads/stores
    IS_VARLEN: tl.constexpr,  # derived heuristic: cu_seqlens is not None
    DOT_PRECISION: tl.constexpr,  # tl.dot input precision
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat tile index to (sequence, tile) and load the span.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]  # strictly-lower mask for a 16x16 tile
    m_I = o_i[:, None] == o_i[None, :]  # identity (diagonal) mask
    # Advance the base pointers to this (sequence, head).
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    if not USE_TMA:
        # Diagonal sub-blocks A_11 .. A_44 of the 64x64 tile.
        p_A_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_A_33 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_A_44 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)

    # [16, 16]
    # Negated strictly-lower part: first-order term of each diagonal inverse.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)

    # Forward substitution on each diagonal sub-block; bounds are clipped to
    # the remaining rows for partial tiles, and the first two rows of each
    # sub-block need no correction (hence the `+ 2` starts).
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    for i in range(32 + 2, min(48, T - i_t * BT)):
        b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
    for i in range(48 + 2, min(64, T - i_t * BT)):
        b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
    # Add the identity diagonal to complete each 16x16 inverse.
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    b_Ai_33 += m_I
    b_Ai_44 += m_I

    if not USE_TMA:
        # Off-diagonal sub-blocks of the input needed for back-substitution.
        p_A_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_A_31 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_A_32 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_A_41 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_A_42 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_A_43 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)

    # First sub-diagonal: Ai_{k+1,k} = -Ai_{k+1,k+1} @ A_{k+1,k} @ Ai_{k,k}.
    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )
    b_Ai_32 = -tl.dot(
        tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
        b_Ai_22,
        input_precision=DOT_PRECISION,
    )
    b_Ai_43 = -tl.dot(
        tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
        b_Ai_33,
        input_precision=DOT_PRECISION,
    )

    # Second sub-diagonal, using the first sub-diagonal results above:
    #   Ai_31 = -Ai_33 @ (A_31 @ Ai_11 + A_32 @ Ai_21)
    #   Ai_42 = -Ai_44 @ (A_42 @ Ai_22 + A_43 @ Ai_32)
    b_Ai_31 = -tl.dot(
        b_Ai_33,
        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_42 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    # Corner: Ai_41 = -Ai_44 @ (A_41 @ Ai_11 + A_42 @ Ai_21 + A_43 @ Ai_31).
    b_Ai_41 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_Ai_33 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_Ai_44 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_31 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_Ai_32 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_Ai_41 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_Ai_42 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_Ai_43 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        # Store with round-to-nearest-even when downcasting from fp32.
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_33,
            b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_44,
            b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_31,
            b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_32,
            b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_41,
            b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_42,
            b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_43,
            b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.

    A should be strictly lower triangular, i.e., ``A.triu() == 0``.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A.

    Raises:
        ValueError: If the last dimension of ``A`` is not 16, 32, or 64.
    """
    B, T, H, BT = A.shape
    # Explicit validation + dict dispatch. The original used `assert` followed
    # by an else-less if/elif chain: under `python -O` the assert is stripped
    # and an invalid BT would surface as a confusing NameError on `merge_fn`.
    kernels = {
        16: solve_tril_16x16_kernel,
        32: merge_16x16_to_32x32_inverse_kernel,
        64: merge_16x16_to_64x64_inverse_kernel,
    }
    if BT not in kernels:
        raise ValueError(f"A.shape[-1] must be one of 16, 32, or 64, but got {BT}")
    merge_fn = kernels[BT]

    output_dtype = A.dtype if output_dtype is None else output_dtype
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    # Grid: one program per (time chunk, batch*head) pair.
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    Ai = torch.zeros_like(A, dtype=output_dtype)
    merge_fn[NT, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai