Coverage for src/flag_gems/fused/FLA/solve_tril.py: 12%

226 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6 

7import os 

8 

9import torch 

10import triton 

11import triton.language as tl 

12 

13from flag_gems.fused.FLA.index import prepare_chunk_indices 

14from flag_gems.fused.FLA.triton_ops_helper import make_tensor_descriptor 

15from flag_gems.fused.FLA.utils import input_guard, is_tma_supported 

16from flag_gems.utils import libentry, libtuner 

17 

# Precision mode read once at import time from the environment; solve_tril
# forwards it to the kernels as DOT_PRECISION, which they hand to
# tl.dot(input_precision=...).
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32", "tf32x3"]
# Validate with an explicit raise instead of an `assert` statement: asserts
# are stripped under `python -O`, which would let an unsupported value through
# to the kernels. AssertionError is kept so the failure type seen at import
# time is the same as before.
if FLA_TRIL_PRECISION not in ALLOWED_TRIL_PRECISIONS:
    raise AssertionError(
        f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
    )

23 

24 

@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    """Process one 16x16 tile of A per program and write the result tile to Ai.

    The tile's strictly-lower-triangular part is negated, rows 2..15 are then
    rewritten one at a time from the rows already finished, and finally the
    identity is added. The host wrapper documents the overall result as
    (I + A)^-1.

    Grid: axis 0 selects the 16-row chunk, axis 1 the flattened (batch, head).

    Args:
        A: input pointer, read with row stride H * BT and unit column stride.
        Ai: output pointer, written with row stride H * 16.
        cu_seqlens: cumulative sequence offsets, or None; IS_VARLEN is derived
            from it by the heuristic above.
        chunk_indices: (sequence index, chunk index) pairs, read only when
            IS_VARLEN is set.
        T: sequence length; replaced by the per-sequence length when IS_VARLEN.
        H: number of heads.
        BT: last-dimension size of A.
        USE_TMA: use tensor descriptors instead of block pointers for the
            tile load/store.
        IS_VARLEN: whether cu_seqlens/chunk_indices drive the indexing.
        DOT_PRECISION: not used by this kernel; present so all three kernels
            share one launch signature.
    """
    chunk, pid_bh = tl.program_id(0), tl.program_id(1)
    batch = pid_bh // H
    head = pid_bh % H
    if IS_VARLEN:
        # Both lookups use the launch-grid index, so read the sequence id
        # before `chunk` is overwritten with the chunk index in that sequence.
        seq = tl.load(chunk_indices + chunk * 2).to(tl.int32)
        chunk = tl.load(chunk_indices + chunk * 2 + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + seq + 1).to(tl.int32)
        T = seq_end - seq_start
    else:
        seq_start = batch * T

    lane = tl.arange(0, 16)
    below_diag = lane[:, None] > lane[None, :]
    on_diag = lane[:, None] == lane[None, :]

    # Move both base pointers to this (sequence start, head).
    A = A + (seq_start * H + head) * BT
    Ai = Ai + (seq_start * H + head) * 16

    row0 = chunk * 16
    # Column of the tile inside the BT-wide row (0 whenever BT == 16).
    col0 = row0 % BT
    if USE_TMA:
        src_desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        dst_desc = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
        tile = src_desc.load([row0, col0]).to(tl.float32)
    else:
        src_ptr = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0, col0), (16, 16), (1, 0)
        )
        # [16, 16]
        tile = tl.load(src_ptr, boundary_check=(0, 1)).to(tl.float32)
    # Keep only the strictly-lower part, negated.
    tile = -tl.where(below_diag, tile, 0)

    # Row-by-row update; the loop bound also stops at the last valid row of a
    # partial final chunk.
    for i in range(2, min(16, T - row0)):
        # [16]
        row = -tl.load(A + (row0 + i) * H * BT + lane + col0)
        row = row + tl.sum(row[:, None] * tile, 0)
        tile = tl.where((lane == i)[:, None], row, tile)
    tile += on_diag

    if USE_TMA:
        dst_desc.store(
            [row0, 0], tile.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
    else:
        dst_ptr = tl.make_block_ptr(
            Ai, (T, 16), (H * 16, 1), (row0, 0), (16, 16), (1, 0)
        )
        tl.store(
            dst_ptr,
            tile.to(dst_ptr.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )

99 

100 

@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    """Process one 32-row chunk of A as a 2x2 grid of 16x16 tiles.

    The two diagonal tiles are handled exactly like the stand-alone 16x16
    kernel (negate the strictly-lower part, row-by-row update, add identity).
    The off-diagonal tile is then formed as -(inv_22 @ A_21) @ inv_11. Only
    the three lower tiles are written to Ai.

    Grid: axis 0 selects the chunk, axis 1 the flattened (batch, head).

    Args:
        A: input pointer, read with row stride H * BT and unit column stride.
        Ai: output pointer, same strides as A.
        cu_seqlens: cumulative sequence offsets, or None; IS_VARLEN is derived
            from it by the heuristic above.
        chunk_indices: (sequence index, chunk index) pairs, read only when
            IS_VARLEN is set.
        T: sequence length; replaced by the per-sequence length when IS_VARLEN.
        H: number of heads.
        BT: chunk size / last-dimension size of A (column offsets 0 and 16 are
            hard-coded, so this kernel is written for BT == 32).
        USE_TMA: use tensor descriptors instead of block pointers.
        IS_VARLEN: whether cu_seqlens/chunk_indices drive the indexing.
        DOT_PRECISION: input_precision passed to every tl.dot.
    """
    chunk, pid_bh = tl.program_id(0), tl.program_id(1)
    batch = pid_bh // H
    head = pid_bh % H
    if IS_VARLEN:
        # Read the sequence id first: both lookups index with the grid id.
        seq = tl.load(chunk_indices + chunk * 2).to(tl.int32)
        chunk = tl.load(chunk_indices + chunk * 2 + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + seq + 1).to(tl.int32)
        T = seq_end - seq_start
    else:
        seq_start = batch * T

    lane = tl.arange(0, 16)
    below_diag = lane[:, None] > lane[None, :]
    on_diag = lane[:, None] == lane[None, :]
    A += (seq_start * H + head) * BT
    Ai += (seq_start * H + head) * BT

    row0 = chunk * BT

    # Load the two diagonal 16x16 tiles.
    if USE_TMA:
        src_desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        dst_desc = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        inv_11 = src_desc.load([row0, 0]).to(tl.float32)
        inv_22 = src_desc.load([row0 + 16, 16]).to(tl.float32)
    else:
        src_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0, 0), (16, 16), (1, 0)
        )
        src_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 16, 16), (16, 16), (1, 0)
        )
        inv_11 = tl.load(src_11, boundary_check=(0, 1)).to(tl.float32)
        inv_22 = tl.load(src_22, boundary_check=(0, 1)).to(tl.float32)

    # [16, 16]
    inv_11 = -tl.where(below_diag, inv_11, 0)
    inv_22 = -tl.where(below_diag, inv_22, 0)

    # Row-by-row update of each diagonal tile; bounds stop at the last valid
    # row of a partial final chunk.
    rows_left = T - row0
    for i in range(2, min(16, rows_left)):
        row_11 = -tl.load(A + (row0 + i) * H * BT + lane)
        row_11 += tl.sum(row_11[:, None] * inv_11, 0)
        inv_11 = tl.where((lane == i)[:, None], row_11, inv_11)
    for i in range(16 + 2, min(32, rows_left)):
        row_22 = -tl.load(A + (row0 + i) * H * BT + lane + 16)
        row_22 += tl.sum(row_22[:, None] * inv_22, 0)
        inv_22 = tl.where((lane == i - 16)[:, None], row_22, inv_22)

    inv_11 += on_diag
    inv_22 += on_diag

    # Off-diagonal source tile.
    if USE_TMA:
        a_21 = src_desc.load([row0 + 16, 0]).to(tl.float32)
    else:
        src_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 16, 0), (16, 16), (1, 0)
        )
        a_21 = tl.load(src_21, boundary_check=(0, 1)).to(tl.float32)

    # inv_21 = -(inv_22 @ A_21) @ inv_11
    left_21 = tl.dot(inv_22, a_21, input_precision=DOT_PRECISION)
    inv_21 = -tl.dot(left_21, inv_11, input_precision=DOT_PRECISION)

    if USE_TMA:
        dst_desc.store(
            [row0, 0], inv_11.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 16, 0], inv_21.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 16, 16], inv_22.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
    else:
        dst_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0, 0), (16, 16), (1, 0)
        )
        dst_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 16, 0), (16, 16), (1, 0)
        )
        dst_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 16, 16), (16, 16), (1, 0)
        )
        tl.store(
            dst_11,
            inv_11.to(dst_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_22,
            inv_22.to(dst_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_21,
            inv_21.to(dst_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )

225 

226 

@libentry()
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@libtuner(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    """Process one 64-row chunk of A as a 4x4 grid of 16x16 tiles.

    Step 1: each of the four diagonal tiles is handled like the stand-alone
    16x16 kernel (negate strictly-lower part, row-by-row update, add identity).
    Step 2: the six tiles below the diagonal are built from the diagonal
    results and the matching source tiles of A:

        inv_21 = -(inv_22 @ A_21) @ inv_11
        inv_32 = -(inv_33 @ A_32) @ inv_22
        inv_43 = -(inv_44 @ A_43) @ inv_33
        inv_31 = -inv_33 @ (A_31 @ inv_11 + A_32 @ inv_21)
        inv_42 = -inv_44 @ (A_42 @ inv_22 + A_43 @ inv_32)
        inv_41 = -inv_44 @ (A_41 @ inv_11 + A_42 @ inv_21 + A_43 @ inv_31)

    Only the ten lower tiles are written to Ai.

    Grid: axis 0 selects the chunk, axis 1 the flattened (batch, head).

    Args:
        A: input pointer, read with row stride H * BT and unit column stride.
        Ai: output pointer, same strides as A.
        cu_seqlens: cumulative sequence offsets, or None; IS_VARLEN is derived
            from it by the heuristic above.
        chunk_indices: (sequence index, chunk index) pairs, read only when
            IS_VARLEN is set.
        T: sequence length; replaced by the per-sequence length when IS_VARLEN.
        H: number of heads.
        BT: chunk size / last-dimension size of A (column offsets 0..48 are
            hard-coded, so this kernel is written for BT == 64).
        USE_TMA: use tensor descriptors instead of block pointers.
        IS_VARLEN: whether cu_seqlens/chunk_indices drive the indexing.
        DOT_PRECISION: input_precision passed to every tl.dot.
    """
    chunk, pid_bh = tl.program_id(0), tl.program_id(1)
    batch = pid_bh // H
    head = pid_bh % H
    if IS_VARLEN:
        # Read the sequence id first: both lookups index with the grid id.
        seq = tl.load(chunk_indices + chunk * 2).to(tl.int32)
        chunk = tl.load(chunk_indices + chunk * 2 + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + seq + 1).to(tl.int32)
        T = seq_end - seq_start
    else:
        seq_start = batch * T

    lane = tl.arange(0, 16)
    below_diag = lane[:, None] > lane[None, :]
    on_diag = lane[:, None] == lane[None, :]
    A += (seq_start * H + head) * BT
    Ai += (seq_start * H + head) * BT

    row0 = chunk * BT

    # ---- Step 1: diagonal tiles -------------------------------------------
    if USE_TMA:
        src_desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        dst_desc = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        inv_11 = src_desc.load([row0, 0]).to(tl.float32)
        inv_22 = src_desc.load([row0 + 16, 16]).to(tl.float32)
        inv_33 = src_desc.load([row0 + 32, 32]).to(tl.float32)
        inv_44 = src_desc.load([row0 + 48, 48]).to(tl.float32)
    else:
        src_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0, 0), (16, 16), (1, 0)
        )
        src_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 16, 16), (16, 16), (1, 0)
        )
        src_33 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 32, 32), (16, 16), (1, 0)
        )
        src_44 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 48, 48), (16, 16), (1, 0)
        )
        inv_11 = tl.load(src_11, boundary_check=(0, 1)).to(tl.float32)
        inv_22 = tl.load(src_22, boundary_check=(0, 1)).to(tl.float32)
        inv_33 = tl.load(src_33, boundary_check=(0, 1)).to(tl.float32)
        inv_44 = tl.load(src_44, boundary_check=(0, 1)).to(tl.float32)

    # [16, 16]
    inv_11 = -tl.where(below_diag, inv_11, 0)
    inv_22 = -tl.where(below_diag, inv_22, 0)
    inv_33 = -tl.where(below_diag, inv_33, 0)
    inv_44 = -tl.where(below_diag, inv_44, 0)

    # Row-by-row update of each diagonal tile; bounds stop at the last valid
    # row of a partial final chunk.
    rows_left = T - row0
    for i in range(2, min(16, rows_left)):
        row_11 = -tl.load(A + (row0 + i) * H * BT + lane)
        row_11 += tl.sum(row_11[:, None] * inv_11, 0)
        inv_11 = tl.where((lane == i)[:, None], row_11, inv_11)
    for i in range(16 + 2, min(32, rows_left)):
        row_22 = -tl.load(A + (row0 + i) * H * BT + lane + 16)
        row_22 += tl.sum(row_22[:, None] * inv_22, 0)
        inv_22 = tl.where((lane == i - 16)[:, None], row_22, inv_22)
    for i in range(32 + 2, min(48, rows_left)):
        row_33 = -tl.load(A + (row0 + i) * H * BT + lane + 32)
        row_33 += tl.sum(row_33[:, None] * inv_33, 0)
        inv_33 = tl.where((lane == i - 32)[:, None], row_33, inv_33)
    for i in range(48 + 2, min(64, rows_left)):
        row_44 = -tl.load(A + (row0 + i) * H * BT + lane + 48)
        row_44 += tl.sum(row_44[:, None] * inv_44, 0)
        inv_44 = tl.where((lane == i - 48)[:, None], row_44, inv_44)
    inv_11 += on_diag
    inv_22 += on_diag
    inv_33 += on_diag
    inv_44 += on_diag

    # ---- Step 2: tiles below the diagonal -----------------------------------
    if USE_TMA:
        a_21 = src_desc.load([row0 + 16, 0]).to(tl.float32)
        a_31 = src_desc.load([row0 + 32, 0]).to(tl.float32)
        a_32 = src_desc.load([row0 + 32, 16]).to(tl.float32)
        a_41 = src_desc.load([row0 + 48, 0]).to(tl.float32)
        a_42 = src_desc.load([row0 + 48, 16]).to(tl.float32)
        a_43 = src_desc.load([row0 + 48, 32]).to(tl.float32)
    else:
        src_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 16, 0), (16, 16), (1, 0)
        )
        src_31 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 32, 0), (16, 16), (1, 0)
        )
        src_32 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 32, 16), (16, 16), (1, 0)
        )
        src_41 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 48, 0), (16, 16), (1, 0)
        )
        src_42 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 48, 16), (16, 16), (1, 0)
        )
        src_43 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (row0 + 48, 32), (16, 16), (1, 0)
        )
        a_21 = tl.load(src_21, boundary_check=(0, 1)).to(tl.float32)
        a_31 = tl.load(src_31, boundary_check=(0, 1)).to(tl.float32)
        a_32 = tl.load(src_32, boundary_check=(0, 1)).to(tl.float32)
        a_41 = tl.load(src_41, boundary_check=(0, 1)).to(tl.float32)
        a_42 = tl.load(src_42, boundary_check=(0, 1)).to(tl.float32)
        a_43 = tl.load(src_43, boundary_check=(0, 1)).to(tl.float32)

    # First sub-diagonal: each tile needs only its two neighbouring diagonals.
    inv_21 = -tl.dot(
        tl.dot(inv_22, a_21, input_precision=DOT_PRECISION),
        inv_11,
        input_precision=DOT_PRECISION,
    )
    inv_32 = -tl.dot(
        tl.dot(inv_33, a_32, input_precision=DOT_PRECISION),
        inv_22,
        input_precision=DOT_PRECISION,
    )
    inv_43 = -tl.dot(
        tl.dot(inv_44, a_43, input_precision=DOT_PRECISION),
        inv_33,
        input_precision=DOT_PRECISION,
    )

    # Second and third sub-diagonals reuse the tiles computed just above, so
    # the order of these statements matters.
    rhs_31 = tl.dot(a_31, inv_11, input_precision=DOT_PRECISION) + tl.dot(
        a_32, inv_21, input_precision=DOT_PRECISION
    )
    inv_31 = -tl.dot(inv_33, rhs_31, input_precision=DOT_PRECISION)

    rhs_42 = tl.dot(a_42, inv_22, input_precision=DOT_PRECISION) + tl.dot(
        a_43, inv_32, input_precision=DOT_PRECISION
    )
    inv_42 = -tl.dot(inv_44, rhs_42, input_precision=DOT_PRECISION)

    rhs_41 = (
        tl.dot(a_41, inv_11, input_precision=DOT_PRECISION)
        + tl.dot(a_42, inv_21, input_precision=DOT_PRECISION)
        + tl.dot(a_43, inv_31, input_precision=DOT_PRECISION)
    )
    inv_41 = -tl.dot(inv_44, rhs_41, input_precision=DOT_PRECISION)

    # ---- Write-back: diagonal tiles first, then the six lower tiles ---------
    if USE_TMA:
        dst_desc.store(
            [row0, 0], inv_11.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 16, 16], inv_22.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 32, 32], inv_33.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 48, 48], inv_44.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 16, 0], inv_21.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 32, 0], inv_31.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 32, 16], inv_32.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 48, 0], inv_41.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 48, 16], inv_42.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
        dst_desc.store(
            [row0 + 48, 32], inv_43.to(dst_desc.dtype, fp_downcast_rounding="rtne")
        )
    else:
        dst_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0, 0), (16, 16), (1, 0)
        )
        dst_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 16, 16), (16, 16), (1, 0)
        )
        dst_33 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 32, 32), (16, 16), (1, 0)
        )
        dst_44 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 48, 48), (16, 16), (1, 0)
        )
        dst_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 16, 0), (16, 16), (1, 0)
        )
        dst_31 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 32, 0), (16, 16), (1, 0)
        )
        dst_32 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 32, 16), (16, 16), (1, 0)
        )
        dst_41 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 48, 0), (16, 16), (1, 0)
        )
        dst_42 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 48, 16), (16, 16), (1, 0)
        )
        dst_43 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (row0 + 48, 32), (16, 16), (1, 0)
        )
        tl.store(
            dst_11,
            inv_11.to(dst_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_22,
            inv_22.to(dst_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_33,
            inv_33.to(dst_33.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_44,
            inv_44.to(dst_44.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_21,
            inv_21.to(dst_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_31,
            inv_31.to(dst_31.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_32,
            inv_32.to(dst_32.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_41,
            inv_41.to(dst_41.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_42,
            inv_42.to(dst_42.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            dst_43,
            inv_43.to(dst_43.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )

504 

505 

@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
            When given, chunks are enumerated per sequence via
            `prepare_chunk_indices` instead of over the full `T` axis.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A

    Raises:
        AssertionError: if the last dimension of `A` is not 16, 32, or 64.
        ValueError: if `A` is not 4-dimensional.
    """
    # One kernel per supported chunk size.
    kernels = {
        16: solve_tril_16x16_kernel,
        32: merge_16x16_to_32x32_inverse_kernel,
        64: merge_16x16_to_64x64_inverse_kernel,
    }
    merge_fn = kernels.get(A.shape[-1])
    if merge_fn is None:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; previously an unsupported BT would then leave the kernel
        # unselected and fail later with an UnboundLocalError. The exception
        # type is kept as AssertionError for existing callers.
        raise AssertionError(
            f"solve_tril expects A.shape[-1] in {sorted(kernels)}, got {A.shape[-1]}"
        )
    if A.ndim != 4:
        raise ValueError(
            f"solve_tril expects A of shape [B, T, H, BT], got {A.ndim} dimensions"
        )
    if output_dtype is None:
        output_dtype = A.dtype

    B, T, H, BT = A.shape
    # NOTE(review): the kernels ignore the batch index when cu_seqlens is
    # given, so the varlen path presumably expects B == 1 - confirm with callers.
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    # Zero-initialised because the kernels only write tiles on and below the
    # block diagonal; everything else must stay zero.
    Ai = torch.zeros_like(A, dtype=output_dtype)

    # Grid: one program per (chunk, batch * head).
    merge_fn[NT, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        # Passed through as imported; presumably a boolean capability flag -
        # verify against flag_gems.fused.FLA.utils.
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai
556 return Ai