Coverage for src/flag_gems/runtime/backend/_ascend/fla/solve_tril.py: 0%
180 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
9# ruff: noqa: E501
10# mypy: ignore-errors
11import torch
12import triton
13import triton.language as tl
15from .utils import prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    LARGE_BLOCK_T: tl.constexpr,
):
    """Invert (I + A) for each 16x16 diagonal sub-block of A.

    A holds strictly lower-triangular [BT, BT] blocks per (token, head);
    the inverted 16x16 diagonal blocks are written to Ad (stride 16 per
    head).  Each program covers LARGE_BLOCK_T rows, processed in NTASKS
    sequential chunks of N_BLOCKS 16x16 tiles; the 16-step inner loop is a
    row-by-row forward substitution shared across all tiles of a chunk.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flat tile id to (sequence id, tile id in seq).
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Shift base pointers to this (sequence, head).
    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16

    block_base = i_t * LARGE_BLOCK_T

    NTASKS: tl.constexpr = 2
    N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS

    for taskid in range(0, NTASKS):
        # Fix: derive each task's base from block_base instead of accumulating
        # into a running variable (the original `base_t += taskid * ...` was
        # only correct for NTASKS == 2; this form is correct for any NTASKS
        # and identical for NTASKS == 2).
        base_t = block_base + taskid * (LARGE_BLOCK_T // NTASKS)

        # Gather N_BLOCKS 16x16 diagonal tiles of A into one 3-D register tile.
        b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
        for blkid in range(0, N_BLOCKS):
            row_start_o = base_t + blkid * 16
            # Column of the diagonal 16x16 tile inside its [BT, BT] block.
            col_start_o = row_start_o % BT

            # 1. In-tile offsets.
            offs_rows_in_block = tl.arange(0, 16)
            offs_cols_in_block = tl.arange(0, 16)

            # 2. Per-element pointers (row stride is H * BT).
            ptr_A_subrec16 = (
                A
                + row_start_o * H * BT
                + col_start_o
                + offs_rows_in_block[:, None] * H * BT
                + offs_cols_in_block[None, :]
            )

            # 3. Mask out rows past the sequence end / cols past the block.
            global_rows = row_start_o + offs_rows_in_block[:, None]
            global_cols = col_start_o + offs_cols_in_block[None, :]
            load_mask = (global_rows < T) & (global_cols < BT)

            # 4. Masked load; out-of-bounds elements become 0.
            b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(
                tl.float32
            )
            b_A = tl.insert_slice(
                ful=b_A,
                sub=b_A_subrec16[None, :, :],  # (1, 16, 16)
                offsets=[blkid, 0, 0],
                sizes=[1, 16, 16],
                strides=[1, 1, 1],
            )

        # Keep an untouched copy laid out as (16, 16 * N_BLOCKS) so row i of
        # every tile can be pulled with a single extract_slice below.
        local_ori_A = tl.trans(b_A, (1, 0, 2))
        local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))

        # Strictly-lower mask expressed as an elementwise product to avoid
        # per-element loops (which would blow up vector memory).
        tmp = tl.arange(0, 16).to(tl.float32)
        rows = tmp[:, None]
        cols = tmp[None, :]
        is_lower = (rows > cols).to(b_A.dtype)
        b_A = -b_A * is_lower

        # Forward substitution: row i of the inverse depends only on rows < i,
        # already finalized in b_A.  All N_BLOCKS tiles advance in lockstep.
        for i in range(1, 16):
            nblks_vec16 = -tl.extract_slice(
                local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)
            )
            b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))

            # Row-vector x matrix product per tile, done via broadcast + sum.
            dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
            dot_product = tl.sum(dot_tmp, 0)
            b_a = b_a + dot_product

            b_a_new_expanded = b_a[:, None, :]
            b_A = tl.insert_slice(
                ful=b_A,
                sub=b_a_new_expanded,
                offsets=[0, i, 0],
                sizes=[N_BLOCKS, 1, 16],
                strides=[1, 1, 1],
            )

        # Add the identity part: diagonal entries of the inverse are 1 here.
        on_diagonal = rows == cols
        b_A = tl.where(on_diagonal, b_A + 1.0, b_A)

        # Flatten tiles into (N_BLOCKS * 16, 16) rows for a single store.
        b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))

        # 1. In-tile store offsets.
        offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
        offs_cols_to_store = tl.arange(0, 16)

        # 2. Per-element store pointers (Ad row stride is H * 16).
        # (A dead tl.make_block_ptr assignment that was immediately
        # overwritten here has been removed.)
        p_Ai = (
            Ad
            + base_t * H * 16
            + offs_rows_to_store[:, None] * H * 16
            + offs_cols_to_store[None, :]
        )
        # 3. Mask: only row bounds need checking (cols always fit in 16).
        global_store_rows = base_t + offs_rows_to_store[:, None]
        store_mask = global_store_rows < T
        # 4. Masked store with round-to-nearest-even downcast.
        tl.store(
            p_Ai,
            b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=store_mask,
        )
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Merge inverted 16x16 diagonal blocks (in Ad) into 32x32 inverses (Ai).

    Uses the block lower-triangular inverse identity:
        [M11   0 ]^-1   [ M11^-1                      0     ]
        [M21  M22]    = [-M22^-1 @ M21 @ M11^-1    M22^-1   ]
    where M11^-1 and M22^-1 are the 16x16 inverses already stored in Ad and
    M21 is the lower-left 16x16 sub-block of the original 32x32 block in A.

    NOTE(review): the upper-right 16x16 sub-block of Ai is never written by
    this kernel, while the caller allocates Ai with torch.empty — presumably
    downstream consumers read only the lower triangle; verify against callers.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flat tile id to (sequence id, tile id in seq).
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Shift base pointers to this (sequence, head); A/Ai blocks are 32 wide,
    # Ad blocks are 16 wide.
    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32

    # Block pointers: A_21 is the lower-left 16x16 of the 32x32 input block;
    # Ad holds the two 16x16 diagonal inverses; Ai_11/22/21 are the three
    # output sub-blocks of the 32x32 inverse.
    p_A_21 = tl.make_block_ptr(
        A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ad_11 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    p_Ad_22 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ai_11 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    p_Ai_22 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
    )
    p_Ai_21 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    # Off-diagonal block of the inverse: -M22^-1 @ M21 @ M11^-1.
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"),
        Ai_11,
        input_precision="ieee",
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Merge inverted 16x16 diagonal blocks (in Ad) into 64x64 inverses (Ai).

    Two-level application of the block lower-triangular inverse identity
        [M11 0; M21 M22]^-1 = [Mi11, 0; -Mi22 @ M21 @ Mi11, Mi22]:
    first the four 16x16 inverses from Ad are combined pairwise into two
    32x32 inverses (Ai_11_32, Ai_22_32), then those are combined into the
    full 64x64 inverse.  Loads and stores use raw pointers with masks.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flat tile id to (sequence id, tile id in seq).
        i_n, i_t_val = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        i_t = i_t_val
    else:
        bos, eos = i_b * T, i_b * T + T

    # Base pointers (already offset by batch and head)
    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64

    # load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16)
    offs_m = i_t * 64 + 16 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_22, A_21, input_precision="ieee")

    # load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16)
    offs_m = i_t * 64 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # Off-diagonal of the first 32x32 inverse: -Ai_22 @ A_21 @ Ai_11.
    Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee")

    # load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16)
    offs_m = i_t * 64 + 48 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # load A_43 (A block at row i_t * 64 + 48, col 32, 16 * 16)
    # (fixed comment: this reads from A, not Ad)
    offs_n = 32 + tl.arange(0, 16)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_44, A_43, input_precision="ieee")

    # load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16)
    offs_m = i_t * 64 + 32 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # Off-diagonal of the second 32x32 inverse: -Ai_44 @ A_43 @ Ai_33.
    Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee")

    # build Ai_22_32 (32 * 32): lower-right 32x32 inverse from its 16x16 parts
    Ai_22_32 = tl.zeros((32, 32), tl.float32)
    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))

    # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
    offs_m = i_t * 64 + 32 + tl.arange(0, 32)
    offs_n = tl.arange(0, 32)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee")

    # build Ai_11_32 (32 * 32): upper-left 32x32 inverse from its 16x16 parts
    Ai_11_32 = tl.zeros((32, 32), tl.float32)
    Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
    Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
    Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))

    # Lower-left 32x32 of the 64x64 inverse: -Ai_22_32 @ A_21_32 @ Ai_11_32.
    Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")

    # store Ai_11_32 to (i_t * 64, 0)
    offs_m = i_t * 64 + tl.arange(0, 32)
    offs_n = tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(
        ptr_Ai,
        Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=mask_store,
    )

    # store Ai_22_32 to (i_t * 64 + 32, 32)
    offs_m = i_t * 64 + 32 + tl.arange(0, 32)
    offs_n = 32 + tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(
        ptr_Ai,
        Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=mask_store,
    )

    # store Ai_21_32 to (i_t * 64 + 32, 0)
    # (fixed comment: column offset is 0, the lower-left block; offs_m still
    # addresses rows 32..63 from the previous store)
    offs_n = tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(
        ptr_Ai,
        Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=mask_store,
    )

    # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63),
    # since Ai is allocated uninitialized by the caller
    offs_m = i_t * 64 + tl.arange(0, 32)
    offs_n = 32 + tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT)
    ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :]
    zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty)
    tl.store(ptr_Ai, zero_block, mask=mask_store)
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]
    # Fix: honor the documented `output_dtype=None` contract.  Previously None
    # fell through to torch.empty(dtype=None), which uses the global default
    # dtype rather than A's dtype.
    if output_dtype is None:
        output_dtype = A.dtype

    B, T, H, BT = A.shape
    # 16x16 diagonal-block inverses.  Kept in fp32 when a merge kernel will
    # consume them; downcast directly to output_dtype when BT == 16 (final).
    Ad = torch.empty(
        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
    )

    # Rows handled per program of the 16x16 kernel; chosen so that
    # LARGE_BLOCK_T // 16 // NTASKS divides evenly inside the kernel.
    LARGE_BLOCK_T = 608 * 2

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
        if cu_seqlens is not None
        else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)

    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        LARGE_BLOCK_T=LARGE_BLOCK_T,
        num_warps=1,
        num_stages=4,
    )

    if BT == 16:
        return Ad

    # Merge the 16x16 inverses up to the full BT x BT inverse.
    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = (
        merge_16x16_to_32x32_inverse_kernel
        if BT == 32
        else merge_16x16_to_64x64_inverse_kernel
    )
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=4,
        num_stages=3,
    )
    return Ai