Coverage for src/flag_gems/fused/FLA/fused_recurrent.py: 8%

173 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6import logging 

7 

8import torch 

9import triton 

10import triton.language as tl 

11 

12from flag_gems.fused.FLA.triton_ops_helper import exp 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.heuristics(
    {
        # Compile-time flags derived from which optional arguments are provided.
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
# This kernel is specialized for the Qwen3-Next model.
# It requires modifications to the calling logic for Qwen3-Next:
# Refer to the rearrange_mixed_qkv logic in the benchmark, where setting contiguous=False
# can provide a certain performance boost by avoiding unnecessary contiguous operations.
# Unlike fused_recurrent_gated_delta_rule_fwd_kernel below, q/k/v are addressed
# through explicit per-dimension strides, so non-contiguous layouts are accepted.
def fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel(
    q,  # query pointer; addressed via stride_q_* (non-contiguous OK)
    k,  # key pointer; addressed via stride_k_*
    v,  # value pointer; addressed via stride_v_*
    g,  # per-(token, value-head) log-space gate; exp() applied in-kernel
    beta,  # update gate; scalar per head, or length-V vector if IS_BETA_HEADWISE
    o,  # output pointer, laid out as [NK, total_tokens, HV, V]
    h0,  # optional initial recurrent state(s)
    ht,  # per-token / final recurrent state output
    cu_seqlens,  # cumulative sequence lengths for varlen batching (or None)
    ssm_state_indices,  # state-slot indices for continuous batching (or None)
    num_accepted_tokens,  # accepted-token counts for speculative decoding (or None)
    scale,  # scaling factor applied to q
    N: tl.int64,  # number of sequences (unused in-kernel; kept for launch parity)
    T: tl.int64,  # tokens per sequence, or total tokens when IS_VARLEN
    # stride_q_b: tl.int64,
    stride_q_t: tl.int64,
    stride_q_h: tl.int64,
    stride_q_k: tl.int64,
    # stride_k_b: tl.int64,
    stride_k_t: tl.int64,
    stride_k_h: tl.int64,
    stride_k_k: tl.int64,
    # stride_v_b: tl.int64,
    stride_v_t: tl.int64,
    stride_v_hv: tl.int64,
    stride_v_v: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of v heads (GQA: presumably a multiple of H)
    K: tl.constexpr,  # q/k head dim
    V: tl.constexpr,  # v head dim
    BK: tl.constexpr,  # block size along K (launcher asserts NK == 1)
    BV: tl.constexpr,  # block size along V
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,  # NOTE(review): unused in this kernel body
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    # One program handles one (K-block, V-block, sequence x value-head) tuple.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to its shared q/k head (grouped-query attention).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence boundaries from cu_seqlens; `all` keeps the total token count
        # for output addressing.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # Lane offsets within the K / V blocks.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Base pointers at the first token of this sequence.
    p_q = q + bos * stride_q_t + i_h * stride_q_h + o_k * stride_q_k
    p_k = k + bos * stride_k_t + i_h * stride_k_h + o_k * stride_k_k
    p_v = v + bos * stride_v_t + i_hv * stride_v_hv + o_v * stride_v_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    p_g = g + bos * HV + i_hv

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Bounds masks for partially-filled blocks.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent state [BK, BV]; accumulated in fp32 for numerical stability.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Resume from the state after the last accepted draft token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over tokens: gate, delta-rule update, emit output.
    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # In-kernel L2 normalization of q and k (eps avoids div-by-zero).
            b_q *= tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k *= tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q *= scale
        # Decay the state: h <- exp(g) * h.  [BK, BV]
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # Delta rule: remove the state's current prediction of v.  [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: h <- h + k v^T.  [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Output: o_t = h^T q.  [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # NOTE(review): this branch dereferences ssm_state_indices, so
            # INPLACE_FINAL_STATE is presumed to imply continuous batching —
            # confirm against callers. Also, the token offset is added as a raw
            # `+ i_t` without `stride_indices_tok`; correct only when that
            # stride is 1 — TODO confirm.
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance all pointers to the next token.
        p_q += stride_q_t
        p_k += stride_k_t
        p_v += stride_v_t
        p_o += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

177 

178 

@triton.heuristics(
    {
        # Compile-time flags derived from which optional arguments are provided.
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
# General fused recurrent gated delta-rule forward. q/k/v are addressed with
# fixed (contiguous-layout) arithmetic; the launcher falls back to the
# Qwen3-Next variant above for non-contiguous inputs.
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,  # queries, contiguous [*, H, K] token layout
    k,  # keys, contiguous [*, H, K] token layout
    v,  # values, contiguous [*, HV, V] token layout
    g,  # log-space gate: per value-head scalar, or per key-dim vector if IS_KDA
    beta,  # update gate; scalar per head, or length-V vector if IS_BETA_HEADWISE
    o,  # output, laid out as [NK, total_tokens, HV, V]
    h0,  # optional initial recurrent state(s)
    ht,  # per-token / final recurrent state output
    cu_seqlens,  # cumulative sequence lengths for varlen batching (or None)
    ssm_state_indices,  # state-slot indices for continuous batching (or None)
    num_accepted_tokens,  # accepted-token counts for speculative decoding (or None)
    scale,  # scaling factor applied to q
    N: tl.int64,  # num of sequences (unused in-kernel; kept for launch parity)
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of v heads (GQA: presumably a multiple of H)
    K: tl.constexpr,  # q/k head dim
    V: tl.constexpr,  # v head dim
    BK: tl.constexpr,  # block size along K (launcher asserts NK == 1)
    BV: tl.constexpr,  # block size along V
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,  # NOTE(review): unused in this kernel body
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,  # gate is a [K]-vector per head instead of a scalar
):
    # One program handles one (K-block, V-block, sequence x value-head) tuple.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to its shared q/k head (grouped-query attention).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence boundaries from cu_seqlens; `all` keeps the total token
        # count for output addressing.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # Lane offsets within the K / V blocks.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Base pointers at the first token of this sequence (contiguous layout).
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        # KDA: a length-K gate vector per (token, value-head).
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Bounds masks for partially-filled blocks.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent state [BK, BV]; accumulated in fp32 for numerical stability.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Resume from the state after the last accepted draft token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over tokens: gate, delta-rule update, emit output.
    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # In-kernel L2 normalization of q and k (eps avoids div-by-zero).
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # Decay the state: h <- exp(g) * h.  [BK, BV]
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Per-key-dim decay broadcast over the V axis.
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # Delta rule: remove the state's current prediction of v.  [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: h <- h + k v^T.  [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Output: o_t = h^T q.  [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # NOTE(review): this branch dereferences ssm_state_indices, so
            # INPLACE_FINAL_STATE is presumed to imply continuous batching —
            # confirm against callers. Also, the token offset is added as a raw
            # `+ i_t` without `stride_indices_tok`; correct only when that
            # stride is 1 — TODO confirm.
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance all pointers to the next token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

334 

335 

def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta-rule forward kernel.

    Picks between two kernels: the general one for contiguous q/k/v, and a
    Qwen3-Next-specialized one that takes explicit strides for non-contiguous
    inputs.

    Args:
        q: queries of shape [B, T, H, K].
        k: keys of shape [B, T, H, K] (B/T/H/K are read from k.shape).
        v: values of shape [B, T, HV, V].
        g: log-space gating tensor (per token and value-head).
        beta: update gate; treated as headwise when beta.ndim == v.ndim.
        scale: scaling factor applied to q inside the kernel.
        initial_state: initial recurrent state(s); may be None (the kernel
            heuristics disable USE_INITIAL_STATE in that case).
        inplace_final_state: if True, final states are written into the
            ``initial_state`` buffer (slots selected via ``ssm_state_indices``);
            otherwise a fresh per-token state buffer is allocated.
        cu_seqlens: cumulative sequence lengths for varlen batching, or None.
        ssm_state_indices: state-slot indices for continuous batching, or None.
        num_accepted_tokens: accepted-token counts for speculative decoding,
            or None.
        use_qk_l2norm_in_kernel: L2-normalize q and k inside the kernel.

    Returns:
        (o, final_state): the attention output with the leading NK axis
        squeezed away, and the recurrent-state buffer.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    # One K block (the launcher requires NK == 1); V is tiled in blocks of <= 32.
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    # Layout probe used to choose the fast fixed-stride kernel; non-matching
    # (e.g. rearranged, non-contiguous) inputs fall back to the explicit-stride
    # Qwen3-Next kernel.
    qkv_contiguous = (
        (q.stride(0) == q.stride(1) + q.stride(2))
        and (k.stride(0) == k.stride(1) + k.stride(2))
        and (v.stride(0) == v.stride(1) + v.stride(2))
    )

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        # Fix: the original read initial_state.dtype unconditionally and crashed
        # when initial_state is None (a case the kernel heuristics support).
        state_dtype = (
            initial_state.dtype if initial_state is not None else torch.float32
        )
        final_state = q.new_empty(T, HV, K, V, dtype=state_dtype)

    # Fix: guard .stride(0) against None states (the strides are only read by
    # the kernel under USE_INITIAL_STATE / when ht is written).
    stride_init_state_token = (
        initial_state.stride(0) if initial_state is not None else 0
    )
    stride_final_state_token = (
        final_state.stride(0) if final_state is not None else 0
    )

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    # One program per (K-block, V-block, sequence x value-head).
    grid = (NK, NV, N * HV)
    if qkv_contiguous:
        fused_recurrent_gated_delta_rule_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            o=o,
            h0=initial_state,
            ht=final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            scale=scale,
            N=N,
            T=T,
            B=B,
            H=H,
            HV=HV,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
            stride_init_state_token=stride_init_state_token,
            stride_final_state_token=stride_final_state_token,
            stride_indices_seq=stride_indices_seq,
            stride_indices_tok=stride_indices_tok,
            IS_BETA_HEADWISE=beta.ndim == v.ndim,
            USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
            INPLACE_FINAL_STATE=inplace_final_state,
            IS_KDA=False,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        # Fix: logging arguments are evaluated eagerly regardless of log level,
        # so cu_seqlens.shape raised AttributeError when cu_seqlens is None.
        logger.debug(
            "GEMS fused_recurrent_gated_delta_rule_fwd, "
            "[q.shape]: %s, [q.stride]: %s, "
            "[k.shape]: %s, [k.stride]: %s, "
            "[v.shape]: %s, [v.stride]: %s, "
            "[g.shape]: %s, [beta.shape]: %s, [initial_state.shape]: %s, "
            "[cu_seqlens.shape]: %s, N: %s, T: %s, B: %s, H: %s, HV: %s, K: %s, V: %s",
            q.shape,
            q.stride(),
            k.shape,
            k.stride(),
            v.shape,
            v.stride(),
            g.shape,
            beta.shape,
            initial_state.shape if initial_state is not None else None,
            cu_seqlens.shape if cu_seqlens is not None else None,
            N,
            T,
            B,
            H,
            HV,
            K,
            V,
        )
        fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            o=o,
            h0=initial_state,
            ht=final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            scale=scale,
            N=N,
            T=T,
            B=B,
            H=H,
            HV=HV,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
            stride_init_state_token=stride_init_state_token,
            stride_final_state_token=stride_final_state_token,
            stride_indices_seq=stride_indices_seq,
            stride_indices_tok=stride_indices_tok,
            # stride_q_b=q.stride(0),
            stride_q_t=q.stride(1),
            stride_q_h=q.stride(2),
            stride_q_k=q.stride(3),
            # stride_k_b=k.stride(0),
            stride_k_t=k.stride(1),
            stride_k_h=k.stride(2),
            stride_k_k=k.stride(3),
            # stride_v_b=v.stride(0),
            stride_v_t=v.stride(1),
            stride_v_hv=v.stride(2),
            stride_v_v=v.stride(3),
            IS_BETA_HEADWISE=beta.ndim == v.ndim,
            USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
            INPLACE_FINAL_STATE=inplace_final_state,
            IS_SPEC_DECODING=num_accepted_tokens is not None,
            IS_CONTINUOUS_BATCHING=ssm_state_indices is not None,
            IS_VARLEN=cu_seqlens is not None,
            USE_INITIAL_STATE=initial_state is not None,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    # Drop the NK axis (guaranteed 1 by the assert above).
    o = o.squeeze(0)
    return o, final_state