Coverage for src/flag_gems/runtime/backend/_ascend/fla/solve_tril.py: 0%

180 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from .utils import prepare_chunk_indices 

16 

17 

@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    LARGE_BLOCK_T: tl.constexpr,
):
    # Invert the 16x16 diagonal blocks of A and write the results to Ad.
    #
    # For each 16-row block starting at row r, the 16x16 tile at columns
    # [r % BT, r % BT + 16) is read from A. The inverse is built row by row
    # (forward substitution on the negated strictly-lower part, then an
    # identity diagonal is added), which yields (I + A_block)^-1 when the
    # block is strictly lower triangular.
    #
    # Args:
    #   A: input pointer, addressed with row stride H * BT.
    #   Ad: output pointer, addressed with row stride H * 16.
    #   cu_seqlens: cumulative sequence offsets, or None for fixed length.
    #   chunk_indices: (sequence, chunk) index pairs; read only if IS_VARLEN.
    #   T: sequence length; replaced by eos - bos when IS_VARLEN.
    #   H: number of heads.
    #   BT: row width of A.
    #   IS_VARLEN: set by the heuristic when cu_seqlens is not None.
    #   LARGE_BLOCK_T: number of rows one program covers along grid axis 0.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat program id to (sequence index, chunk index in sequence).
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Move both base pointers to row `bos`, head `i_h`.
    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16

    # Each program processes its LARGE_BLOCK_T rows in NTASKS equal slices,
    # and each slice holds N_BLOCKS diagonal blocks of 16 rows.
    NTASKS: tl.constexpr = 2
    N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS

    for taskid in range(0, NTASKS):
        # First row of this slice. Computed directly from the loop index so
        # that it stays correct for any NTASKS (an accumulated `+=` is only
        # right for NTASKS == 2).
        base_t = i_t * LARGE_BLOCK_T + taskid * (LARGE_BLOCK_T // NTASKS)

        # Gather the N_BLOCKS diagonal 16x16 tiles with explicit pointers.
        b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
        for blkid in range(0, N_BLOCKS):
            row_start_o = base_t + blkid * 16
            col_start_o = row_start_o % BT

            # 1 Create in-block offset
            offs_rows_in_block = tl.arange(0, 16)
            offs_cols_in_block = tl.arange(0, 16)

            # 2 Calculate the pointer of each element
            ptr_A_subrec16 = (
                A
                + row_start_o * H * BT
                + col_start_o
                + offs_rows_in_block[:, None] * H * BT
                + offs_cols_in_block[None, :]
            )

            # 3 Create a mask to prevent out-of-bounds access
            global_rows = row_start_o + offs_rows_in_block[:, None]
            global_cols = col_start_o + offs_cols_in_block[None, :]
            load_mask = (global_rows < T) & (global_cols < BT)

            # 4 Use mask to safely load data (masked-out elements read as 0)
            b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(
                tl.float32
            )
            b_A = tl.insert_slice(
                ful=b_A,
                sub=b_A_subrec16[None, :, :],  # (1, 16, 16)
                offsets=[blkid, 0, 0],
                sizes=[1, 16, 16],
                strides=[1, 1, 1],
            )

        # Keep an untouched copy laid out as (row-in-block, block * 16 + col)
        # so that row i of every block can be sliced out at once.
        local_ori_A = tl.trans(b_A, (1, 0, 2))
        local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))

        # Apply the strictly-lower mask with one elementwise multiply instead
        # of a per-row loop (the original note cites UB OOM with loops).
        tmp = tl.arange(0, 16).to(tl.float32)
        rows = tmp[:, None]
        cols = tmp[None, :]
        is_lower = (rows > cols).to(b_A.dtype)
        b_A = -b_A * is_lower

        # Forward substitution: update row i of all N_BLOCKS blocks together.
        for i in range(1, 16):
            nblks_vec16 = -tl.extract_slice(
                local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)
            )
            b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))

            # dot_product[n, k] = sum_j b_a[n, j] * b_A[n, j, k]
            dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
            dot_product = tl.sum(dot_tmp, 0)
            b_a = b_a + dot_product

            b_a_new_expanded = b_a[:, None, :]
            b_A = tl.insert_slice(
                ful=b_A,
                sub=b_a_new_expanded,
                offsets=[0, i, 0],
                sizes=[N_BLOCKS, 1, 16],
                strides=[1, 1, 1],
            )

        # Add the identity on the diagonal of every block.
        on_diagonal = rows == cols
        b_A = tl.where(on_diagonal, b_A + 1.0, b_A)

        b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))

        # 1 Create in-block offset
        offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
        offs_cols_to_store = tl.arange(0, 16)

        # 2 Calculate the pointer of each element
        p_Ai = (
            Ad
            + base_t * H * 16
            + 0
            + offs_rows_to_store[:, None] * H * 16
            + offs_cols_to_store[None, :]
        )
        # 3 Create a mask to prevent out-of-bounds access, only check rows
        global_store_rows = base_t + offs_rows_to_store[:, None]
        store_mask = global_store_rows < T
        # 4 use mask to save data safely
        tl.store(
            p_Ai,
            b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=store_mask,
        )

152 

153 

@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Assemble one 32x32 inverse tile of Ai from two 16x16 diagonal inverses.
    #
    # With the diagonal inverses Ai_11 and Ai_22 read from Ad and the
    # lower-left block A_21 read from A, the tile written to Ai is
    #     [[Ai_11,                 0    ],
    #      [-Ai_22 @ A_21 @ Ai_11, Ai_22]]
    #
    # Args:
    #   A: input pointer, addressed with row stride H * 32.
    #   Ad: pointer to the 16x16 diagonal inverses, row stride H * 16.
    #   Ai: output pointer, addressed with row stride H * 32.
    #   cu_seqlens: cumulative sequence offsets, or None for fixed length.
    #   chunk_indices: (sequence, chunk) index pairs; read only if IS_VARLEN.
    #   T: sequence length; replaced by eos - bos when IS_VARLEN.
    #   H: number of heads.
    #   BT: accepted for a uniform launch signature; not read in this kernel.
    #   IS_VARLEN: set by the heuristic when cu_seqlens is not None.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat program id to (sequence index, chunk index in sequence).
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Move all base pointers to row `bos`, head `i_h`.
    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32

    p_A_21 = tl.make_block_ptr(
        A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ad_11 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    p_Ad_22 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ai_11 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    # Upper-right block (rows 0..15, cols 16..31) of the output tile.
    p_Ai_12 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32, 16), (16, 16), (1, 0)
    )
    p_Ai_22 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
    )
    p_Ai_21 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"),
        Ai_11,
        input_precision="ieee",
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    # The output buffer may be uninitialized, so the upper-right block must be
    # written explicitly; otherwise it would expose stale memory.
    tl.store(
        p_Ai_12,
        tl.zeros((16, 16), dtype=tl.float32).to(
            p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"
        ),
        boundary_check=(0, 1),
    )

228 

229 

@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Assemble one 64x64 inverse tile of Ai from four 16x16 diagonal inverses.
    #
    # Step 1 merges the 16x16 inverses pairwise into two 32x32 diagonal
    # inverses; step 2 merges those into the 64x64 tile. Each merge computes
    # the lower-left block as -(inv_lower @ A_lower_left @ inv_upper).
    # The upper-right 32x32 block of the tile is written as zeros.
    #
    # Args:
    #   A: input pointer, addressed with row stride H * 64.
    #   Ad: pointer to the 16x16 diagonal inverses, row stride H * 16.
    #   Ai: output pointer, addressed with row stride H * 64.
    #   cu_seqlens: cumulative sequence offsets, or None for fixed length.
    #   chunk_indices: (sequence, chunk) index pairs; read only if IS_VARLEN.
    #   T: sequence length; replaced by the sequence extent when IS_VARLEN.
    #   H: number of heads.
    #   BT: row width used only when zeroing the upper-right block.
    #   IS_VARLEN: set by the heuristic when cu_seqlens is not None.
    pid_chunk = tl.program_id(0)
    pid_bh = tl.program_id(1)
    head = pid_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence index, chunk index) pairs.
        seq = tl.load(chunk_indices + pid_chunk * 2).to(tl.int32)
        chunk = tl.load(chunk_indices + pid_chunk * 2 + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + seq + 1).to(tl.int32)
        T = seq_end - seq_start
    else:
        chunk = pid_chunk
        seq_start = (pid_bh // H) * T

    # Move all base pointers to the first row of this sequence and head.
    head_slot = seq_start * H + head
    A += head_slot * 64
    Ad += head_slot * 16
    Ai += head_slot * 64

    lane16 = tl.arange(0, 16)
    lane32 = tl.arange(0, 32)
    row0 = chunk * 64  # first row of this 64-row tile

    # ---- step 1a: merge diagonal blocks 1 and 2 (rows 0..31) ----
    # inverse of block 2: Ad rows [row0 + 16, row0 + 32)
    rows = row0 + 16 + lane16
    keep = (rows[:, None] < T) & (lane16[None, :] < 16)
    src = Ad + rows[:, None] * (H * 16) + lane16[None, :]
    inv_22 = tl.load(src, mask=keep, other=0.0).to(tl.float32)

    # A block at rows [row0 + 16, row0 + 32), cols [0, 16)
    keep = (rows[:, None] < T) & (lane16[None, :] < 64)
    src = A + rows[:, None] * (H * 64) + lane16[None, :]
    a_21 = tl.load(src, mask=keep, other=0.0).to(tl.float32)
    prod = tl.dot(inv_22, a_21, input_precision="ieee")

    # inverse of block 1: Ad rows [row0, row0 + 16)
    rows = row0 + lane16
    keep = (rows[:, None] < T) & (lane16[None, :] < 16)
    src = Ad + rows[:, None] * (H * 16) + lane16[None, :]
    inv_11 = tl.load(src, mask=keep, other=0.0).to(tl.float32)

    inv_21 = -tl.dot(prod, inv_11, input_precision="ieee")

    # ---- step 1b: merge diagonal blocks 3 and 4 (rows 32..63) ----
    # inverse of block 4: Ad rows [row0 + 48, row0 + 64)
    rows = row0 + 48 + lane16
    keep = (rows[:, None] < T) & (lane16[None, :] < 16)
    src = Ad + rows[:, None] * (H * 16) + lane16[None, :]
    inv_44 = tl.load(src, mask=keep, other=0.0).to(tl.float32)

    # A block at rows [row0 + 48, row0 + 64), cols [32, 48)
    cols = 32 + lane16
    keep = (rows[:, None] < T) & (cols[None, :] < 64)
    src = A + rows[:, None] * (H * 64) + cols[None, :]
    a_43 = tl.load(src, mask=keep, other=0.0).to(tl.float32)
    prod = tl.dot(inv_44, a_43, input_precision="ieee")

    # inverse of block 3: Ad rows [row0 + 32, row0 + 48)
    rows = row0 + 32 + lane16
    keep = (rows[:, None] < T) & (lane16[None, :] < 16)
    src = Ad + rows[:, None] * (H * 16) + lane16[None, :]
    inv_33 = tl.load(src, mask=keep, other=0.0).to(tl.float32)

    inv_43 = -tl.dot(prod, inv_33, input_precision="ieee")

    # ---- step 2: merge the two 32x32 inverses ----
    # lower-right 32x32 inverse: [[inv_33, 0], [inv_43, inv_44]]
    inv_lower = tl.zeros((32, 32), tl.float32)
    inv_lower = tl.insert_slice(inv_lower, inv_33, (0, 0), (16, 16), (1, 1))
    inv_lower = tl.insert_slice(inv_lower, inv_44, (16, 16), (16, 16), (1, 1))
    inv_lower = tl.insert_slice(inv_lower, inv_43, (16, 0), (16, 16), (1, 1))

    # A block at rows [row0 + 32, row0 + 64), cols [0, 32)
    rows = row0 + 32 + lane32
    keep = (rows[:, None] < T) & (lane32[None, :] < 64)
    src = A + rows[:, None] * (H * 64) + lane32[None, :]
    a_lower_left = tl.load(src, mask=keep, other=0.0).to(tl.float32)
    prod = tl.dot(inv_lower, a_lower_left, input_precision="ieee")

    # upper-left 32x32 inverse: [[inv_11, 0], [inv_21, inv_22]]
    inv_upper = tl.zeros((32, 32), tl.float32)
    inv_upper = tl.insert_slice(inv_upper, inv_11, (0, 0), (16, 16), (1, 1))
    inv_upper = tl.insert_slice(inv_upper, inv_22, (16, 16), (16, 16), (1, 1))
    inv_upper = tl.insert_slice(inv_upper, inv_21, (16, 0), (16, 16), (1, 1))

    inv_cross = -tl.dot(prod, inv_upper, input_precision="ieee")

    # ---- write the four 32x32 quadrants of the output tile ----
    # upper-left quadrant: rows [row0, row0 + 32), cols [0, 32)
    rows = row0 + lane32
    keep = (rows[:, None] < T) & (lane32[None, :] < 64)
    dst = Ai + rows[:, None] * (H * 64) + lane32[None, :]
    tl.store(
        dst,
        inv_upper.to(dst.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=keep,
    )

    # lower-right quadrant: rows [row0 + 32, row0 + 64), cols [32, 64)
    rows = row0 + 32 + lane32
    cols = 32 + lane32
    keep = (rows[:, None] < T) & (cols[None, :] < 64)
    dst = Ai + rows[:, None] * (H * 64) + cols[None, :]
    tl.store(
        dst,
        inv_lower.to(dst.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=keep,
    )

    # lower-left quadrant: same rows as above, cols [0, 32)
    keep = (rows[:, None] < T) & (lane32[None, :] < 64)
    dst = Ai + rows[:, None] * (H * 64) + lane32[None, :]
    tl.store(
        dst,
        inv_cross.to(dst.dtype.element_ty, fp_downcast_rounding="rtne"),
        mask=keep,
    )

    # upper-right quadrant: rows [row0, row0 + 32), cols [32, 64) set to zero
    rows = row0 + lane32
    cols = 32 + lane32
    keep = (rows[:, None] < T) & (cols[None, :] < BT)
    dst = Ai + rows[:, None] * (H * BT) + cols[None, :]
    zeros_quadrant = tl.zeros((32, 32), dtype=dst.dtype.element_ty)
    tl.store(dst, zeros_quadrant, mask=keep)

370 

371 

def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A
    A should be strictly lower triangular, i.e., A.triu() == 0.

    The 16x16 diagonal blocks are inverted first; for BT of 32 or 64 a second
    kernel merges them into the full BT x BT inverse.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]

    # Honor the documented contract: passing None to torch.empty would pick
    # the global default dtype rather than the input dtype.
    if output_dtype is None:
        output_dtype = A.dtype

    B, T, H, BT = A.shape
    # Diagonal 16x16 inverses. Kept in float32 when they are only an
    # intermediate for the merge step; otherwise they are the final output.
    Ad = torch.empty(
        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
    )

    # Rows handled by one program of the 16x16 kernel (the kernel splits this
    # into two slices of 608 rows, i.e. 38 blocks of 16 rows each).
    LARGE_BLOCK_T = 608 * 2

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
        if cu_seqlens is not None
        else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)

    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        LARGE_BLOCK_T=LARGE_BLOCK_T,
        num_warps=1,
        num_stages=4,
    )

    if BT == 16:
        return Ad

    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = (
        merge_16x16_to_32x32_inverse_kernel
        if BT == 32
        else merge_16x16_to_64x64_inverse_kernel
    )
    # The merge kernels work on chunks of BT rows, so the chunk table is
    # rebuilt at that granularity.
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=4,
        num_stages=3,
    )
    return Ai