Coverage for src/flag_gems/runtime/backend/_sunrise/fused/fused_recurrent.py: 0%

261 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6import logging 

7 

8import torch 

9import triton 

10import triton.language as tl 

11 

12from flag_gems.fused.FLA.triton_ops_helper import exp 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
# This kernel is specialized for Qwen3-Next model.
# It requires modifications to the calling logic for Qwen3-Next:
# Refer to the rearrange_mixed_qkv logic in the benchmark, where setting contiguous=False
# can provide a certain performance boost by avoiding unnecessary contiguous operations.
#
# Strided variant of the recurrent gated-delta-rule forward pass: q/k/v are
# addressed through explicit token/head/feature strides, so they do not have
# to be contiguous. No batch stride is taken; sequence `i_n` is addressed as
# token index `bos` times the token stride.
#
# Each program handles one (K-block, V-block, sequence * value-head) triple
# and walks the tokens of its sequence in order, keeping a [BK, BV] float32
# state tile in registers/local storage as decided by the compiler.
#
# NOTE(review): `stride_indices_tok` is accepted but never read; the token
# offset into `ssm_state_indices` is added with an implicit stride of 1.
def fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,
    T: tl.int64,
    stride_q_t: tl.int64,
    stride_q_h: tl.int64,
    stride_q_k: tl.int64,
    stride_k_t: tl.int64,
    stride_k_h: tl.int64,
    stride_k_k: tl.int64,
    stride_v_t: tl.int64,
    stride_v_hv: tl.int64,
    stride_v_v: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    # Grid axes: 0 -> K block, 1 -> V block, 2 -> fused (sequence, value head).
    i_k = tl.program_id(0)
    i_v = tl.program_id(1)
    i_nh = tl.program_id(2)
    i_n = i_nh // HV
    i_hv = i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Sequence boundaries come from the cumulative-length table.
        bos = tl.load(cu_seqlens + i_n).to(tl.int64)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        total_tokens = T
        T = eos - bos
    else:
        bos = i_n * T
        total_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # q/k/v go through caller-provided strides; g, beta and o are addressed
    # as densely packed token-major buffers.
    p_q = q + bos * stride_q_t + i_h * stride_q_h + o_k * stride_q_k
    p_k = k + bos * stride_k_t + i_h * stride_k_h + o_k * stride_k_k
    p_v = v + bos * stride_v_t + i_hv * stride_v_hv + o_v * stride_v_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    p_g = g + bos * HV + i_hv

    p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # Load the state slot and skip padded sequences: a negative slot
            # (PAD_SLOT_ID = -1) would otherwise address memory before `h0`.
            # This mirrors fused_recurrent_gated_delta_rule_large_t_fwd_kernel.
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                tl.int64
            )
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q *= tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k *= tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q *= scale
        # [BK, BV] -- decay the running state by the per-token scalar gate
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # [BV] -- delta rule: subtract the state's current prediction for k
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV] -- rank-1 state update
        b_h += b_k[:, None] * b_v[None, :]
        # [BV] -- read the output out of the updated state
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t
            ).to(tl.int64)
            # Only store when the slot is valid (not PAD_SLOT_ID).
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance every pointer by one token.
        p_q += stride_q_t
        p_k += stride_k_t
        p_v += stride_v_t
        p_o += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

177 

178 

# Recurrent gated-delta-rule forward pass for densely packed q/k/v.
#
# q/k are addressed as token-major buffers with H heads of width K, v/o with
# HV heads of width V (see the pointer arithmetic below). Each program handles
# one (K-block, V-block, sequence * value-head) triple and walks the tokens of
# its sequence in order, carrying a [BK, BV] float32 state tile.
#
# With IS_KDA the gate `g` is a length-K vector per (token, value head);
# otherwise it is a single scalar per (token, value head).
#
# NOTE(review): `stride_indices_tok` is accepted but never read; the token
# offset into `ssm_state_indices` is added with an implicit stride of 1.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    # Grid axes: 0 -> K block, 1 -> V block, 2 -> fused (sequence, value head).
    i_k = tl.program_id(0)
    i_v = tl.program_id(1)
    i_nh = tl.program_id(2)
    i_n = i_nh // HV
    i_hv = i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Sequence boundaries come from the cumulative-length table.
        bos = tl.load(cu_seqlens + i_n).to(tl.int64)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        total_tokens = T
        T = eos - bos
    else:
        bos = i_n * T
        total_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # Load the state slot and skip padded sequences: a negative slot
            # (PAD_SLOT_ID = -1) would otherwise address memory before `h0`.
            # This mirrors fused_recurrent_gated_delta_rule_large_t_fwd_kernel.
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                tl.int64
            )
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV] -- decay the running state by the gate
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Mask lanes beyond K: when BK > K an unmasked load reads past the
            # gate row, and a stray NaN/inf there would poison the reductions
            # over K below (NaN * 0 is NaN). other=0 gives exp(0) = 1, which
            # leaves the already-zero padded rows of b_h untouched.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # [BV] -- delta rule: subtract the state's current prediction for k
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV] -- rank-1 state update
        b_h += b_k[:, None] * b_v[None, :]
        # [BV] -- read the output out of the updated state
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t
            ).to(tl.int64)
            # Only store when the slot is valid (not PAD_SLOT_ID).
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance every pointer by one token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

334 

335 

# Recurrent gated-delta-rule forward pass that keeps the state tile transposed.
#
# Same recurrence as the [BK, BV] kernels in this file, but the running state
# is held and stored as [BV, BK] (value index major, key index minor), so the
# reductions over K run along axis 1. q/k/v/g/beta/o are addressed as densely
# packed token-major buffers.
#
# Sequences whose state slot is negative (PAD_SLOT_ID = -1) are skipped on
# load, and their per-token state stores are suppressed.
#
# NOTE(review): `stride_indices_tok` is accepted but never read; the token
# offset into `ssm_state_indices` is added with an implicit stride of 1.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_large_t_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    # Grid axes: 0 -> K block, 1 -> V block, 2 -> fused (sequence, value head).
    i_k = tl.program_id(0)
    i_v = tl.program_id(1)
    i_nh = tl.program_id(2)
    i_n = i_nh // HV
    i_hv = i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Sequence boundaries come from the cumulative-length table.
        bos = tl.load(cu_seqlens + i_n).to(tl.int64)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        total_tokens = T
        T = eos - bos
    else:
        bos = i_n * T
        total_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    # State tile is [BV, BK], so the V mask spans rows and the K mask columns.
    mask_h = mask_v[:, None] & mask_k[None, :]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # Load state index and check for PAD_SLOT_ID (-1)
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                tl.int64
            )
            # Skip if state index is invalid (PAD_SLOT_ID = -1)
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BV, BK] -- decay the running state by the gate
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Mask lanes beyond K: when BK > K an unmasked load reads past the
            # gate row, and a stray NaN/inf there would poison the reductions
            # over K below (NaN * 0 is NaN). other=0 gives exp(0) = 1, which
            # leaves the already-zero padded columns of b_h untouched.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[None, :])
        # [BV] -- delta rule: subtract the state's current prediction for k
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BV, BK] -- rank-1 state update
        b_h += b_v[:, None] * b_k[None, :]
        # [BV] -- read the output out of the updated state
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Load state index and check for PAD_SLOT_ID (-1)
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t
            ).to(tl.int64)
            # Only store if state index is valid (not PAD_SLOT_ID)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance every pointer by one token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

494 

495 

def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the recurrent gated-delta-rule forward kernel.

    Kernel selection:
      * q, k and v all contiguous -> ``fused_recurrent_gated_delta_rule_fwd_kernel``
      * otherwise, ``T <= 64``     -> the strided Qwen3-Next kernel (no copies)
      * otherwise                  -> the large-T kernel, fed contiguous copies

    When ``use_qk_l2norm_in_kernel`` is False, q/k/v are made contiguous up
    front, so the first kernel is always chosen.

    Args:
        q: Queries; indexed like ``k``.
        k: Keys, 4-D, unpacked as ``(B, T, H, K)``.
        v: Values, 4-D; ``v.shape[2]`` is the number of value heads ``HV`` and
            ``v.shape[-1]`` the value width ``V``.
        g: Per-token gate (log-space; the kernels apply ``exp``).
        beta: Per-token update strength. Treated as head-wise (one value per
            value channel) when ``beta.ndim == v.ndim``, else one scalar per
            (token, value head).
        scale: Multiplier applied to q inside the kernel.
        initial_state: State buffer read as the initial state; its ``stride(0)``
            is used as the per-slot stride.
        inplace_final_state: If True, states are written back into
            ``initial_state`` at the slots named by ``ssm_state_indices``;
            otherwise a fresh buffer with one state per token is allocated.
        cu_seqlens: Optional cumulative sequence lengths; when given, the
            number of sequences is ``len(cu_seqlens) - 1``.
        ssm_state_indices: Optional 1-D or 2-D tensor of state slot indices.
        num_accepted_tokens: Optional per-sequence counts; passed through to
            the kernel, which uses them to pick the initial-state slot.
        use_qk_l2norm_in_kernel: Whether the kernel L2-normalises q and k.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v`` and
        ``final_state`` is ``initial_state`` itself when
        ``inplace_final_state`` is True.
    """
    logger.debug("GEMS FUSED RECURRENT GATED DELTA RULE FWD")
    if not use_qk_l2norm_in_kernel:
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

    B, T, H, K = k.shape
    HV, V = v.shape[2], v.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    # PTPU shows large bf16 forward drift on the gated-delta recurrence when
    # the value block is 32 wide and q/k are not L2-normalized.
    max_bv = 8 if not use_qk_l2norm_in_kernel else 32
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), max_bv)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    qkv_contiguous = q.is_contiguous() and k.is_contiguous() and v.is_contiguous()

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        # The kernels store one state per token at row `bos + i_t`, where
        # `bos = i_n * T` in the fixed-length case, so B * T rows are needed.
        # Allocating only T rows is written out of bounds for B > 1; for
        # B == 1 the shape is unchanged.
        final_state = q.new_empty(B * T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)

    # Arguments that are identical for all three kernels.
    shared_kwargs = dict(
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    if qkv_contiguous:
        fused_recurrent_gated_delta_rule_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            IS_KDA=False,
            **shared_kwargs,
        )
    else:
        logger.debug(
            "GEMS fused_recurrent_gated_delta_rule_fwd, "
            "[q.shape]: %s, [q.stride]: %s, "
            "[k.shape]: %s, [k.stride]: %s, "
            "[v.shape]: %s, [v.stride]: %s, "
            "[g.shape]: %s, [beta.shape]: %s, [initial_state.shape]: %s, "
            "[cu_seqlens.shape]: %s, N: %s, T: %s, B: %s, H: %s, HV: %s, K: %s, V: %s",
            q.shape,
            q.stride(),
            k.shape,
            k.stride(),
            v.shape,
            v.stride(),
            g.shape,
            beta.shape,
            initial_state.shape,
            # cu_seqlens is optional; do not dereference it when absent.
            None if cu_seqlens is None else cu_seqlens.shape,
            N,
            T,
            B,
            H,
            HV,
            K,
            V,
        )
        if T <= 64:
            # Short sequences: read q/k/v through their strides, no copies.
            fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel[grid](
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                stride_q_t=q.stride(1),
                stride_q_h=q.stride(2),
                stride_q_k=q.stride(3),
                stride_k_t=k.stride(1),
                stride_k_h=k.stride(2),
                stride_k_k=k.stride(3),
                stride_v_t=v.stride(1),
                stride_v_hv=v.stride(2),
                stride_v_v=v.stride(3),
                IS_SPEC_DECODING=num_accepted_tokens is not None,
                IS_CONTINUOUS_BATCHING=ssm_state_indices is not None,
                IS_VARLEN=cu_seqlens is not None,
                USE_INITIAL_STATE=initial_state is not None,
                **shared_kwargs,
            )
        else:
            # Long sequences: pay for contiguous copies once, then use the
            # dense-layout kernel.
            fused_recurrent_gated_delta_rule_large_t_fwd_kernel[grid](
                q=q.contiguous(),
                k=k.contiguous(),
                v=v.contiguous(),
                g=g.contiguous(),
                beta=beta.contiguous(),
                IS_KDA=False,
                **shared_kwargs,
            )
    # NK is asserted to be 1 above, so drop the leading K-block axis.
    o = o.squeeze(0)
    return o, final_state