Coverage for src/flag_gems/runtime/backend/_ascend/fla/fused_recurrent.py: 0%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
4#
5# This file contains code copied from the flash-linear-attention project.
6# The original source code was licensed under the MIT license and included
7# the following copyright notice:
8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
9# ruff: noqa: E501
10# mypy: ignore-errors
12import torch
13import triton
14import triton.language as tl
@triton.jit
def div_normal(x, y):
    """Return ``x / y`` using the plain division operator."""
    quotient = x / y
    return quotient
# Short module-level aliases for the math helpers; `div` points at the plain
# division helper and the rest forward to the triton.language intrinsics.
div = div_normal
exp, log, log2 = tl.exp, tl.log, tl.log2
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    """Token-by-token gated delta-rule recurrence for one (K-block, V-block, sequence*head) program.

    For every token of its sequence the program applies, on a ``[BK, BV]``
    float32 state tile ``b_h``::

        b_h *= exp(gate)                 # scalar gate, or per-K gate when IS_KDA
        b_v  = (b_v - sum_k(b_h * b_k)) * beta
        b_h += outer(b_k, b_v)
        b_o  = sum_k(b_h * (b_q * scale))

    and stores both the output row and the state after that token.

    Addressing used by the kernel (flattened, innermost last):
      * ``q`` / ``k``: ``[token, H, K]``; ``v``: ``[token, HV, V]``.
      * ``g``: ``[token, HV]`` (``[token, HV, K]`` when ``IS_KDA``).
      * ``beta``: ``[token, HV]`` (``[token, HV, V]`` when ``IS_BETA_HEADWISE``).
      * ``o``: ``[NK, total_tokens, HV, V]``.
      * ``h0`` / ``ht``: slots of ``[HV, K, V]`` separated by
        ``stride_init_state_token`` / ``stride_final_state_token`` elements.
      * ``ssm_state_indices``: slot ids addressed as
        ``[seq * stride_indices_seq + tok * stride_indices_tok]``.

    ``USE_INITIAL_STATE``, ``IS_VARLEN``, ``IS_CONTINUOUS_BATCHING`` and
    ``IS_SPEC_DECODING`` are derived by the heuristics above from whether the
    corresponding pointer argument is ``None``.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Several value heads share one query/key head (HV // H value heads each).
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        # T as passed in is the total token count; keep it for the output
        # offset before narrowing T to this sequence's length.
        total_tokens = T
        T = eos - bos
    else:
        bos = i_n * T
        total_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Start from the slot of the last accepted token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = (
                h0
                + tl.load(
                    ssm_state_indices
                    + i_n * stride_indices_seq
                    + i_t * stride_indices_tok
                ).to(tl.int64)
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t

        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t

        if not IS_KDA:
            p_g = g + bos * HV + i_hv + HV * i_t
        else:
            p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k

        p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale

        # [BK, BV] -- decay the state. The gate is loaded only inside the
        # branch that defined its pointer, so both IS_KDA specializations
        # reference defined names only.
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Masked so lanes with o_k >= K are never read; exp(0) == 1 leaves
            # the (zero) padded rows of b_h untouched.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[:, None])

        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta

        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]

        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = (
                ht
                + tl.load(
                    ssm_state_indices
                    + i_n * stride_indices_seq
                    + i_t * stride_indices_tok
                ).to(tl.int64)
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta-rule forward kernel.

    Args:
        q: Queries; not unpacked here, the kernel addresses it like ``k``.
        k: Keys; its shape is unpacked as ``(B, T, H, K)``.
        v: Values; ``v.shape[2]`` is taken as ``HV`` and ``v.shape[-1]`` as ``V``.
        g: Gate values passed straight to the kernel.
        beta: Per-head scalar, or per-head vector when ``beta.ndim == v.ndim``.
        scale: Multiplier applied to the queries inside the kernel.
        initial_state: State buffer read by the kernel; also the output buffer
            when ``inplace_final_state`` is true.
        inplace_final_state: Write per-token states back into
            ``initial_state`` at the slots named by ``ssm_state_indices``.
        cu_seqlens: Optional cumulative sequence offsets; when given, the
            number of sequences is ``len(cu_seqlens) - 1``.
        ssm_state_indices: Optional 1-D or 2-D table of state-slot indices.
        num_accepted_tokens: Optional per-sequence counts forwarded to the kernel.
        use_qk_l2norm_in_kernel: L2-normalise q and k inside the kernel.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v`` and
        ``final_state`` is either ``initial_state`` (in-place mode) or a new
        ``(B * T, HV, K, V)`` buffer holding the state after every token.

    Raises:
        ValueError: If ``inplace_final_state`` is true but no
            ``ssm_state_indices`` is supplied.
    """
    if inplace_final_state and ssm_state_indices is None:
        # The kernel reads the destination slot from ssm_state_indices in
        # in-place mode; fail early with a readable message instead of inside
        # kernel compilation.
        raise ValueError(
            "ssm_state_indices is required when inplace_final_state=True"
        )

    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        # One state slot per token across the whole batch: the kernel writes
        # slot (bos + i_t) with bos = i_n * T for fixed-length batches, so T
        # slots alone would overflow as soon as B > 1.
        final_state = q.new_empty(B * T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=False,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # NK is asserted to be 1 above, so the leading K-block axis can be dropped.
    o = o.squeeze(0)
    return o, final_state