Coverage for src/flag_gems/fused/FLA/fused_recurrent.py: 7%
174 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import logging

import torch
import triton
import triton.language as tl

from flag_gems.fused.FLA.triton_ops_helper import exp

logger = logging.getLogger(__name__)
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
# This kernel is specialized for the Qwen3-Next model.
# It requires modifications to the calling logic for Qwen3-Next:
# Refer to the rearrange_mixed_qkv logic in the benchmark, where setting contiguous=False
# can provide a certain performance boost by avoiding unnecessary contiguous operations.
def fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel(
    q,  # queries, addressed via explicit strides (stride_q_*)
    k,  # keys, addressed via explicit strides (stride_k_*)
    v,  # values, addressed via explicit strides (stride_v_*)
    g,  # per-(token, value-head) log decay gate
    beta,  # update rate; per-V vector or scalar per head (see IS_BETA_HEADWISE)
    o,  # output buffer, addressed as [NK, total_tokens, HV, V]
    h0,  # optional initial recurrent state
    ht,  # per-token / final recurrent state output
    cu_seqlens,  # cumulative sequence lengths for varlen batching (optional)
    ssm_state_indices,  # per-sequence state-slot indices into h0/ht (optional)
    num_accepted_tokens,  # per-sequence accepted-token counts for spec decoding (optional)
    scale,  # scaling factor applied to q before readout
    N: tl.int64,
    T: tl.int64,
    # stride_q_b: tl.int64,
    stride_q_t: tl.int64,
    stride_q_h: tl.int64,
    stride_q_k: tl.int64,
    # stride_k_b: tl.int64,
    stride_k_t: tl.int64,
    stride_k_h: tl.int64,
    stride_k_k: tl.int64,
    # stride_v_b: tl.int64,
    stride_v_t: tl.int64,
    stride_v_hv: tl.int64,
    stride_v_v: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    """Forward recurrent gated delta rule, strided-q/k/v variant.

    One program owns one (K-block, V-block, sequence x value-head) triple
    and scans that sequence's tokens in order, maintaining a [BK, BV]
    state b_h. Per token the state is decayed by exp(g), updated with the
    delta rule (v <- beta * (v - k^T h); h <- h + k v^T), read out as
    o = q^T h, and stored to `ht`. q/k/v are addressed through explicit
    strides so non-contiguous inputs are supported.
    """
    # program ids: K-block, V-block, flattened (sequence, value-head)
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # map the value head to its query/key head (assumes HV is a multiple of H)
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # token range [bos, eos) of this sequence in the packed layout
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T  # total token count across the batch (shadows builtin `all`)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # lane offsets within the K / V tiles
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # pointers to the first token of this sequence, via explicit strides
    p_q = q + bos * stride_q_t + i_h * stride_q_h + o_k * stride_q_k
    p_k = k + bos * stride_k_t + i_h * stride_k_h + o_k * stride_k_k
    p_v = v + bos * stride_v_t + i_hv * stride_v_hv + o_v * stride_v_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    p_g = g + bos * HV + i_hv

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # mask off lanes beyond the true K / V extents
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # resume from the state slot of the last accepted token
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # NOTE(review): i_t is added without scaling by stride_indices_tok
            # (which is otherwise unused in this kernel) — confirm this is
            # intentional for 2-D ssm_state_indices tables.
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # in-kernel L2 normalization of q and k (eps avoids div-by-zero)
            b_q *= tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k *= tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q *= scale
        # decay the state: h <- exp(g) * h   [BK, BV]
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # delta rule residual: v <- v - k^T h   [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # rank-1 state update: h <- h + k v^T   [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # readout: o = q^T h   [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # advance all pointers to the next token
        p_q += stride_q_t
        p_k += stride_k_t
        p_v += stride_v_t
        p_o += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,  # queries, contiguous [tokens, H, K] addressing
    k,  # keys, contiguous [tokens, H, K] addressing
    v,  # values, contiguous [tokens, HV, V] addressing
    g,  # log decay gate: per (token, HV) scalar, or per (token, HV, K) when IS_KDA
    beta,  # update rate; per-V vector or scalar per head (see IS_BETA_HEADWISE)
    o,  # output buffer, addressed as [NK, total_tokens, HV, V]
    h0,  # optional initial recurrent state
    ht,  # per-token / final recurrent state output
    cu_seqlens,  # cumulative sequence lengths for varlen batching (optional)
    ssm_state_indices,  # per-sequence state-slot indices into h0/ht (optional)
    num_accepted_tokens,  # per-sequence accepted-token counts for spec decoding (optional)
    scale,  # scaling factor applied to q before readout
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,  # whether the gate is a per-key vector (KDA) instead of a scalar
):
    """Forward recurrent gated delta rule (contiguous-layout kernel).

    One program owns one (K-block, V-block, sequence x value-head) triple
    and scans that sequence's tokens in order, maintaining a [BK, BV]
    state b_h. Per token the state is decayed by exp(g) — a scalar gate,
    or a per-key vector gate when IS_KDA — updated with the delta rule
    (v <- beta * (v - k^T h); h <- h + k v^T), read out as o = q^T h,
    and stored to `ht`. Unlike the Qwen3-Next variant above, q/k/v are
    assumed densely packed and addressed without explicit strides.
    """
    # program ids: K-block, V-block, flattened (sequence, value-head)
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # map the value head to its query/key head (assumes HV is a multiple of H)
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # token range [bos, eos) of this sequence in the packed layout
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T  # total token count across the batch (shadows builtin `all`)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # lane offsets within the K / V tiles
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # pointers to the first token of this sequence (dense layout)
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        # KDA: one gate value per key dimension
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # mask off lanes beyond the true K / V extents
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # resume from the state slot of the last accepted token
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # NOTE(review): i_t is added without scaling by stride_indices_tok
            # (which is otherwise unused in this kernel) — confirm this is
            # intentional for 2-D ssm_state_indices tables.
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # in-kernel L2 normalization of q and k (eps avoids div-by-zero)
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # decay the state: h <- exp(g) * h   [BK, BV]
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # delta rule residual: v <- v - k^T h   [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # rank-1 state update: h <- h + k v^T   [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # readout: o = q^T h   [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # advance all pointers to the next token
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta rule forward kernel.

    Dispatches to the generic contiguous kernel when q/k/v satisfy the
    fast-path stride check, otherwise to the strided Qwen3-Next kernel.

    Args:
        q: queries of shape [B, T, H, K].
        k: keys of shape [B, T, H, K].
        v: values of shape [B, T, HV, V].
        g: log decay gates (one value per token and value head).
        beta: update rates; headwise when ``beta.ndim == v.ndim``.
        scale: scaling factor applied to q inside the kernel.
        initial_state: initial recurrent state (h0), or None.
        inplace_final_state: if True, final states are written back into
            ``initial_state`` (slots chosen via ``ssm_state_indices``);
            otherwise a fresh per-token [T, HV, K, V] buffer is allocated.
        cu_seqlens: cumulative sequence lengths for varlen batching.
        ssm_state_indices: per-sequence state-slot indices (continuous batching).
        num_accepted_tokens: accepted-token counts for speculative decoding.
        use_qk_l2norm_in_kernel: L2-normalize q/k inside the kernel.

    Returns:
        (o, final_state): output with the same shape as ``v`` and the
        recurrent state buffer.
    """
    logger.debug("GEMS FUSED RECURRENT GATED DELTA RULE FWD")
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    # Fast-path eligibility for the generic (stride-free) kernel.
    # NOTE(review): this compares stride(0) against stride(1) + stride(2);
    # confirm it matches the layout produced by rearrange_mixed_qkv.
    qkv_contiguous = (
        (q.stride(0) == q.stride(1) + q.stride(2))
        and (k.stride(0) == k.stride(1) + k.stride(2))
        and (v.stride(0) == v.stride(1) + v.stride(2))
    )

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        # states are written back into the caller-provided buffer
        final_state = initial_state
    else:
        # one state snapshot per token; fall back to q.dtype when no
        # initial state was provided (h0 may legitimately be None,
        # see the USE_INITIAL_STATE heuristic on the kernels)
        final_state = q.new_empty(
            T,
            HV,
            K,
            V,
            dtype=initial_state.dtype if initial_state is not None else q.dtype,
        )

    # Guard against initial_state=None: the stride is meaningless (and the
    # kernels never read h0) in that case, so pass 0 instead of crashing.
    stride_init_state_token = (
        initial_state.stride(0) if initial_state is not None else 0
    )
    stride_final_state_token = (
        final_state.stride(0) if final_state is not None else 0
    )

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    # one program per (K-block, V-block, sequence x value-head)
    grid = (NK, NV, N * HV)
    if qkv_contiguous:
        fused_recurrent_gated_delta_rule_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            o=o,
            h0=initial_state,
            ht=final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            scale=scale,
            N=N,
            T=T,
            B=B,
            H=H,
            HV=HV,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
            stride_init_state_token=stride_init_state_token,
            stride_final_state_token=stride_final_state_token,
            stride_indices_seq=stride_indices_seq,
            stride_indices_tok=stride_indices_tok,
            IS_BETA_HEADWISE=beta.ndim == v.ndim,
            USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
            INPLACE_FINAL_STATE=inplace_final_state,
            IS_KDA=False,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        logger.debug(
            "GEMS fused_recurrent_gated_delta_rule_fwd, "
            "[q.shape]: %s, [q.stride]: %s, "
            "[k.shape]: %s, [k.stride]: %s, "
            "[v.shape]: %s, [v.stride]: %s, "
            "[g.shape]: %s, [beta.shape]: %s, [initial_state.shape]: %s, "
            "[cu_seqlens.shape]: %s, N: %s, T: %s, B: %s, H: %s, HV: %s, K: %s, V: %s",
            q.shape,
            q.stride(),
            k.shape,
            k.stride(),
            v.shape,
            v.stride(),
            g.shape,
            beta.shape,
            # both may be None (see heuristics); logging must not crash
            initial_state.shape if initial_state is not None else None,
            cu_seqlens.shape if cu_seqlens is not None else None,
            N,
            T,
            B,
            H,
            HV,
            K,
            V,
        )
        fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            o=o,
            h0=initial_state,
            ht=final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            scale=scale,
            N=N,
            T=T,
            B=B,
            H=H,
            HV=HV,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
            stride_init_state_token=stride_init_state_token,
            stride_final_state_token=stride_final_state_token,
            stride_indices_seq=stride_indices_seq,
            stride_indices_tok=stride_indices_tok,
            # stride_q_b=q.stride(0),
            stride_q_t=q.stride(1),
            stride_q_h=q.stride(2),
            stride_q_k=q.stride(3),
            # stride_k_b=k.stride(0),
            stride_k_t=k.stride(1),
            stride_k_h=k.stride(2),
            stride_k_k=k.stride(3),
            # stride_v_b=v.stride(0),
            stride_v_t=v.stride(1),
            stride_v_hv=v.stride(2),
            stride_v_v=v.stride(3),
            IS_BETA_HEADWISE=beta.ndim == v.ndim,
            USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
            INPLACE_FINAL_STATE=inplace_final_state,
            IS_SPEC_DECODING=num_accepted_tokens is not None,
            IS_CONTINUOUS_BATCHING=ssm_state_indices is not None,
            IS_VARLEN=cu_seqlens is not None,
            USE_INITIAL_STATE=initial_state is not None,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    # drop the NK dimension (asserted to be 1 above)
    o = o.squeeze(0)
    return o, final_state