# Coverage for src/flag_gems/fused/FLA/fused_recurrent.py: 8%
# 173 statements
# coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1# This file contains code copied from the flash-linear-attention project.
2# The original source code was licensed under the MIT license and included
3# the following copyright notice:
4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
5# ruff: noqa: E501
6import logging
8import torch
9import triton
10import triton.language as tl
12from flag_gems.fused.FLA.triton_ops_helper import exp
14logger = logging.getLogger(__name__)
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
# This kernel is specialized for Qwen3-Next model.
# It requires modifications to the calling logic for Qwen3-Next:
# Refer to the rearrange_mixed_qkv logic in the benchmark, where setting contiguous=False
# can provide a certain performance boost by avoiding unnecessary contiguous operations.
def fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel(
    q,  # queries; addressed via the explicit stride_q_* arguments below
    k,  # keys; addressed via the explicit stride_k_* arguments below
    v,  # values; addressed via the explicit stride_v_* arguments below
    g,  # log decay gates, one scalar per (token, value head)
    beta,  # update rate; per-head scalar or length-V vector (see IS_BETA_HEADWISE)
    o,  # output buffer, contiguous [NK, tokens, HV, V]
    h0,  # initial recurrent state, or None (see USE_INITIAL_STATE heuristic)
    ht,  # final / per-token recurrent state output buffer
    cu_seqlens,  # cumulative sequence lengths (varlen mode), or None
    ssm_state_indices,  # per-sequence state slot indices (continuous batching), or None
    num_accepted_tokens,  # accepted-token counts (speculative decoding), or None
    scale,  # scaling factor applied to q each step
    N: tl.int64,  # number of sequences
    T: tl.int64,  # tokens per sequence (fixed-length) or total tokens (varlen)
    # stride_q_b: tl.int64,
    stride_q_t: tl.int64,
    stride_q_h: tl.int64,
    stride_q_k: tl.int64,
    # stride_k_b: tl.int64,
    stride_k_t: tl.int64,
    stride_k_h: tl.int64,
    stride_k_k: tl.int64,
    # stride_v_b: tl.int64,
    stride_v_t: tl.int64,
    stride_v_hv: tl.int64,
    stride_v_v: tl.int64,
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of v heads (HV >= H; heads are grouped)
    K: tl.constexpr,  # q/k head dim
    V: tl.constexpr,  # v head dim
    BK: tl.constexpr,  # key-block tile size
    BV: tl.constexpr,  # value-block tile size
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,  # NOTE(review): declared but never used below — confirm intended
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    """Fused recurrent gated delta-rule forward with explicit q/k/v strides.

    One program handles one (key block i_k, value block i_v, flattened
    sequence-and-value-head i_nh) triple and scans that sequence token by
    token, carrying the [BK, BV] recurrent state ``b_h`` in registers.
    Same math as ``fused_recurrent_gated_delta_rule_fwd_kernel``, but q/k/v
    are addressed through caller-provided strides so non-contiguous inputs
    (e.g. Qwen3-Next's rearranged mixed qkv) need no ``.contiguous()`` copy.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Split the flat grid axis into sequence index and value-head index.
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head onto its shared q/k head (grouped value heads).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # T becomes this sequence's own length; `all` keeps the total token
        # count, which sizes the per-key-block slab of `o`.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # Element offsets of this program's key/value tile.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to the first token; advanced by one token at the loop tail.
    p_q = q + bos * stride_q_t + i_h * stride_q_h + o_k * stride_q_k
    p_k = k + bos * stride_k_t + i_h * stride_k_h + o_k * stride_k_k
    p_v = v + bos * stride_v_t + i_hv * stride_v_hv + o_v * stride_v_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    p_g = g + bos * HV + i_hv

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Bounds masks for (possibly padded) K/V tiles.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # With speculative decoding resume from the state of the last
            # accepted token; otherwise from this sequence's slot 0.
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # NOTE(review): i_t indexes ssm_state_indices directly, without
            # stride_indices_tok — assumes a contiguous last dim; confirm.
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over this sequence's tokens.
    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # L2-normalize q and k; epsilon keeps rsqrt finite at zero.
            b_q *= tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k *= tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q *= scale
        # Decay the state: h *= exp(g).  [BK, BV]
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # Delta rule: subtract the state's current prediction for k.  [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: h += k v'^T.  [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Output: o = q^T h.  [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Scatter the state back into the caller-indexed slot.
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance all pointers by one token.
        p_q += stride_q_t
        p_k += stride_k_t
        p_v += stride_v_t
        p_o += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,  # queries; addressed as a contiguous [tokens, H, K] layout
    k,  # keys; addressed as a contiguous [tokens, H, K] layout
    v,  # values; addressed as a contiguous [tokens, HV, V] layout
    g,  # log decay gates; scalar per (token, v-head), or per-key if IS_KDA
    beta,  # update rate; per-head scalar or length-V vector (see IS_BETA_HEADWISE)
    o,  # output buffer, contiguous [NK, tokens, HV, V]
    h0,  # initial recurrent state, or None (see USE_INITIAL_STATE heuristic)
    ht,  # final / per-token recurrent state output buffer
    cu_seqlens,  # cumulative sequence lengths (varlen mode), or None
    ssm_state_indices,  # per-sequence state slot indices (continuous batching), or None
    num_accepted_tokens,  # accepted-token counts (speculative decoding), or None
    scale,  # scaling factor applied to q each step
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of v heads (HV >= H; heads are grouped)
    K: tl.constexpr,  # q/k head dim
    V: tl.constexpr,  # v head dim
    BK: tl.constexpr,  # key-block tile size
    BV: tl.constexpr,  # value-block tile size
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,  # NOTE(review): declared but never used below — confirm intended
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,  # key-wise decay: gate is a length-K vector instead of a scalar
):
    """Generic fused recurrent gated delta-rule forward kernel.

    One program handles one (key block i_k, value block i_v, flattened
    sequence-and-value-head i_nh) triple and scans that sequence token by
    token, carrying the [BK, BV] recurrent state ``b_h`` in registers.
    Per token:  h *= exp(g);  v' = beta * (v - k^T h);  h += k v'^T;
    o = q^T h.  Assumes q/k/v are effectively contiguous (see the
    ``qkv_contiguous`` check in the Python launcher).
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Split the flat grid axis into sequence index and value-head index.
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head onto its shared q/k head (grouped value heads).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # T becomes this sequence's own length; `all` keeps the total token
        # count, which sizes the per-key-block slab of `o`.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # Element offsets of this program's key/value tile.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to the first token; advanced by one token at the loop tail.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        # Key-wise decay: one gate value per key-dim lane.
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Bounds masks for (possibly padded) K/V tiles.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # With speculative decoding resume from the state of the last
            # accepted token; otherwise from this sequence's slot 0.
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # NOTE(review): i_t indexes ssm_state_indices directly, without
            # stride_indices_tok — assumes a contiguous last dim; confirm.
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over this sequence's tokens.
    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # L2-normalize q and k; epsilon keeps the sqrt denominator nonzero.
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # Decay the state: scalar gate, or per-key-lane gate for KDA.  [BK, BV]
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # Delta rule: subtract the state's current prediction for k.  [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: h += k v'^T.  [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Output: o = q^T h.  [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Scatter the state back into the caller-indexed slot.
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance all pointers by one token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch the fused recurrent gated delta-rule forward kernel.

    Picks the generic kernel when q/k/v satisfy the layout check below,
    otherwise falls back to the Qwen3-Next specialized kernel that takes
    explicit per-dimension strides and so tolerates non-contiguous inputs.

    Args:
        q: queries of shape [B, T, H, K].
        k: keys of shape [B, T, H, K].
        v: values of shape [B, T, HV, V].
        g: per-(token, value-head) log decay gates.
        beta: update rates; headwise vector when ``beta.ndim == v.ndim``,
            otherwise one scalar per head.
        scale: query scaling factor applied inside the kernel.
        initial_state: initial recurrent state buffer (h0).
        inplace_final_state: when True, final states are written back into
            ``initial_state``'s buffer (slots picked via ssm_state_indices);
            otherwise a fresh [T, HV, K, V] buffer is allocated.
        cu_seqlens: cumulative sequence lengths for varlen batches, or None.
        ssm_state_indices: state-slot indices for continuous batching, or None.
        num_accepted_tokens: accepted-token counts for speculative decoding,
            or None.
        use_qk_l2norm_in_kernel: L2-normalize q and k inside the kernel.

    Returns:
        Tuple of (output with v's shape, final recurrent state).
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    # Tile sizes: whole K per program (NK must be 1), V split into <=32 lanes.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 32)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    # Layout probe: the generic kernel hard-codes contiguous token/head
    # addressing and is only valid when each tensor's batch stride equals
    # its token stride plus head stride.
    qkv_contiguous = all(
        t.stride(0) == t.stride(1) + t.stride(2) for t in (q, k, v)
    )

    o = q.new_empty(NK, *v.shape)
    final_state = (
        initial_state
        if inplace_final_state
        else q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
    )

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    # Index-table strides; dummy 1s when there is no table at all.
    if ssm_state_indices is None:
        stride_indices_seq = stride_indices_tok = 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)

    # Arguments common to both kernel variants.
    shared_args = dict(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
    )

    if qkv_contiguous:
        fused_recurrent_gated_delta_rule_fwd_kernel[grid](
            IS_KDA=False,
            num_warps=num_warps,
            num_stages=num_stages,
            **shared_args,
        )
    else:
        logger.debug(
            "GEMS fused_recurrent_gated_delta_rule_fwd, "
            "[q.shape]: %s, [q.stride]: %s, "
            "[k.shape]: %s, [k.stride]: %s, "
            "[v.shape]: %s, [v.stride]: %s, "
            "[g.shape]: %s, [beta.shape]: %s, [initial_state.shape]: %s, "
            "[cu_seqlens.shape]: %s, N: %s, T: %s, B: %s, H: %s, HV: %s, K: %s, V: %s",
            q.shape,
            q.stride(),
            k.shape,
            k.stride(),
            v.shape,
            v.stride(),
            g.shape,
            beta.shape,
            initial_state.shape,
            cu_seqlens.shape,
            N,
            T,
            B,
            H,
            HV,
            K,
            V,
        )
        fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel[grid](
            # Explicit strides let this variant consume non-contiguous q/k/v.
            # stride_q_b=q.stride(0),
            stride_q_t=q.stride(1),
            stride_q_h=q.stride(2),
            stride_q_k=q.stride(3),
            # stride_k_b=k.stride(0),
            stride_k_t=k.stride(1),
            stride_k_h=k.stride(2),
            stride_k_k=k.stride(3),
            # stride_v_b=v.stride(0),
            stride_v_t=v.stride(1),
            stride_v_hv=v.stride(2),
            stride_v_v=v.stride(3),
            # Passed explicitly, mirroring the kernel's @triton.heuristics.
            IS_SPEC_DECODING=num_accepted_tokens is not None,
            IS_CONTINUOUS_BATCHING=ssm_state_indices is not None,
            IS_VARLEN=cu_seqlens is not None,
            USE_INITIAL_STATE=initial_state is not None,
            num_warps=num_warps,
            num_stages=num_stages,
            **shared_args,
        )

    # Drop the NK axis (asserted to be 1 above).
    return o.squeeze(0), final_state