Coverage for src/flag_gems/fused/FLA/fused_cumsum_kkt_solve_tril.py: 10%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1# Copyright (c) 2025 FlagGems. All rights reserved. 

2# Fused cumsum + KKT + solve_tril for chunk_gated_delta_rule. Returns g_out, A_inv; w_u is separate. 

3# License: Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0) 

4 

5from __future__ import annotations 

6 

7import torch 

8import triton 

9import triton.language as tl 

10 

11from flag_gems.fused.FLA.index import prepare_chunk_indices 

12from flag_gems.fused.FLA.solve_tril import FLA_TRIL_PRECISION 

13from flag_gems.fused.FLA.triton_ops_helper import exp, make_tensor_descriptor 

14from flag_gems.fused.FLA.utils import is_tma_supported 

15from flag_gems.utils import libentry, libtuner 

16 

17 

@libentry()
@triton.heuristics(
    {
        # Varlen (packed-sequence) mode is selected purely by whether
        # cu_seqlens was supplied.
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        # This fused variant always applies the gate; kept as a constexpr
        # flag so the gated branch can be specialized away if reused.
        "USE_G": lambda args: True,
    }
)
@libtuner(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril_kernel(
    g_in,  # log-gates before cumsum; indexed as g_in[bos*H + t*H + i_h]
    g_out,  # per-chunk inclusive cumsum of g_in, same layout as g_in
    k,  # keys; indexed with Hg heads and row stride Hg*K (see p_k below)
    beta,  # per-token scale applied to rows of k, same layout as g_in
    A,  # fp32 scratch [.., T, H, BT]: strictly-lower KKT matrix per chunk
    A_inv,  # output [.., T, H, BT]: inverse of (I + A) per chunk
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat (seq idx, chunk-in-seq idx) pairs (varlen) or None
    T,  # sequence length; rebased per-sequence in varlen mode
    H: tl.constexpr,  # number of heads for g/beta/A
    Hg: tl.constexpr,  # number of key heads (grouped: Hg divides H)
    K: tl.constexpr,  # key head dimension
    BT: tl.constexpr,  # chunk (time-tile) size
    BK: tl.constexpr,  # K-dimension tile size (autotuned)
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
    USE_TMA: tl.constexpr,  # use tensor descriptors instead of block pointers
    DOT_PRECISION: tl.constexpr,  # input_precision for the solve-stage tl.dot
):
    # One program per (time-chunk, batch*head) pair. Fuses three stages:
    #   1) chunk-local cumsum of the log-gates g
    #   2) A = strict_lower((beta * k) @ k^T * exp(g_i - g_j))   ("KKT")
    #   3) A_inv = (I + A)^-1 by blockwise forward substitution  ("solve_tril")
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Translate the flat chunk id into (sequence, chunk-within-sequence),
        # then rebase bos/eos and T to this sequence.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # ---------- cumsum ----------
    # Inclusive cumsum of the chunk's gates in fp32; result kept live (b_g)
    # for the gating factor in the KKT stage below.
    p_g_in = tl.make_block_ptr(
        g_in + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g_out = tl.make_block_ptr(
        g_out + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_g = tl.load(p_g_in, boundary_check=(0,)).to(tl.float32)
    b_g = tl.cumsum(b_g, axis=0)
    tl.store(p_g_out, b_g.to(p_g_out.dtype.element_ty), boundary_check=(0,))

    # ---------- KKT (write L to A) ----------
    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T  # mask of valid (in-sequence) rows in this chunk
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    # Accumulate (beta*k) @ k^T over K in BK-wide slices.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            # i_h // (H // Hg): map this head onto its (possibly shared)
            # key head group.
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
    if USE_G:
        # Gate each (i, j) entry by exp(g_i - g_j); cumsum values from above.
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * exp(b_g_diff)
    # Keep only the strictly-lower triangle restricted to valid rows/cols.
    m_A_kkt = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A_kkt, b_A, 0)
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    # ---------- solve_tril (read A, write A_inv) ----------
    # Inverts the unit-lower-triangular (I + A) chunk by tiling it into a
    # 4x4 grid of 16x16 blocks: the four diagonal blocks are inverted
    # independently, then the six lower off-diagonal blocks are combined
    # via the block-triangular inverse formulas.
    # NOTE(review): the tile offsets 0/16/32/48 hard-code BT == 64 here;
    # confirm callers never pass another chunk size.
    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]  # strict lower mask within a tile
    m_I = o_i[:, None] == o_i[None, :]  # identity mask within a tile
    A_base = A + (bos * H + i_h) * BT
    A_inv_base = A_inv + (bos * H + i_h) * BT

    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_A_33 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_A_44 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A_base, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)

    # Initialize each diagonal-tile inverse with -strict_lower(tile); rows
    # 0 and 1 of the inverse are already exact after this, so the
    # substitution loops below start at row 2.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)

    # Row-by-row forward substitution within each 16x16 diagonal tile;
    # loop bounds are clipped by T so rows past the sequence end are skipped.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    for i in range(32 + 2, min(48, T - i_t * BT)):
        b_a_33 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 32)
        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
    for i in range(48 + 2, min(64, T - i_t * BT)):
        b_a_44 = -tl.load(A_base + (i_t * BT + i) * H * BT + o_i + 48)
        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
    # Add the unit diagonal (boolean identity mask promotes to 0/1).
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    b_Ai_33 += m_I
    b_Ai_44 += m_I

    # Load the six lower off-diagonal 16x16 tiles of A.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_A_31 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_A_32 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_A_41 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_A_42 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_A_43 = tl.make_block_ptr(
            A_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)

    # Off-diagonal blocks of the inverse of a block lower-triangular matrix:
    #   inv_ij = -inv_ii @ (sum_k A_ik @ inv_kj)   for i > j,
    # evaluated in dependency order (21, 32, 43 first; then 31, 42; then 41).
    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )
    b_Ai_32 = -tl.dot(
        tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
        b_Ai_22,
        input_precision=DOT_PRECISION,
    )
    b_Ai_43 = -tl.dot(
        tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
        b_Ai_33,
        input_precision=DOT_PRECISION,
    )
    b_Ai_31 = -tl.dot(
        b_Ai_33,
        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_42 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_41 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )

    # Store the ten lower block-triangular tiles of A_inv; the upper tiles
    # are never written here (the launcher zero-fills A_inv).
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_Ai_33 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_Ai_44 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_31 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_Ai_32 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_Ai_41 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_Ai_42 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_Ai_43 = tl.make_block_ptr(
            A_inv_base, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        # Downcast with round-to-nearest-even to match the TMA path.
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_33,
            b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_44,
            b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_31,
            b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_32,
            b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_41,
            b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_42,
            b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_43,
            b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o = make_tensor_descriptor(A_inv_base, [T, BT], [H * BT, 1], [16, 16])
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )

349 

350 

def chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
    g: torch.Tensor,
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused cumsum + KKT + solve_tril kernel.

    Computes, in one kernel pass per chunk: the chunk-local cumulative sum
    of ``g``, the strictly-lower KKT matrix ``L``, and the inverse of
    ``I + L``. The w/u projection stays a separate kernel (e.g.
    ``recompute_w_u_fwd``) for HGMMA.

    Args:
        g: log-gate tensor, same leading layout as ``beta``.
        k: key tensor of shape ``[B, T, Hg, K]``; ``Hg`` may differ from
            ``H`` (grouped key heads).
        beta: per-token scaling; its last dimension defines ``H``.
        cu_seqlens: optional cumulative sequence lengths enabling varlen
            (packed) mode.
        chunk_size: time-chunk length ``BT``.
        output_dtype: dtype of the returned ``A_inv``; defaults to
            ``k.dtype``.

    Returns:
        ``(g_out, A_inv)``: chunk-local cumsum of ``g`` and the per-chunk
        inverse, with ``A_inv`` shaped ``[B, T, H, BT]``.
    """
    num_batches, seq_len, num_k_heads, head_dim = k.shape
    num_heads = beta.shape[-1]
    BT = chunk_size
    if output_dtype is None:
        output_dtype = k.dtype

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        num_chunks = len(chunk_indices)

    g_out = torch.empty_like(g)
    A = torch.empty(
        num_batches, seq_len, num_heads, BT, device=g.device, dtype=torch.float32
    )
    # Zero-filled on purpose: the kernel writes only the lower
    # block-triangular tiles of A_inv.
    A_inv = torch.zeros(
        num_batches, seq_len, num_heads, BT, device=g.device, dtype=output_dtype
    )

    def grid(meta):
        # One program per (time-chunk, batch*head).
        return (num_chunks, num_batches * num_heads)

    # NOTE(review): `is_tma_supported` is forwarded as-is; confirm it is a
    # boolean value rather than an uncalled predicate.
    chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril_kernel[grid](
        g_in=g,
        g_out=g_out,
        k=k,
        beta=beta,
        A=A,
        A_inv=A_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=seq_len,
        H=num_heads,
        Hg=num_k_heads,
        K=head_dim,
        BT=BT,
        IS_VARLEN=cu_seqlens is not None,
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return g_out, A_inv