Coverage for src/flag_gems/runtime/backend/_ascend/fla/fused_recurrent.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11 

12import torch 

13import triton 

14import triton.language as tl 

15 

16 

@triton.jit
def div_normal(x, y):
    """Plain elementwise division: return ``x / y``."""
    quotient = x / y
    return quotient

20 

21 

# Named aliases for the math primitives used by the kernels below.
# Presumably kept as indirections so a backend can swap in alternative
# (e.g. fast-math) implementations without touching kernel bodies — confirm
# against other backend variants of this file.
div = div_normal
exp = tl.exp
log = tl.log
log2 = tl.log2

26 

27 

@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    """Forward pass of the fused recurrent gated delta rule.

    One program instance owns one (K-block, V-block, sequence*value-head)
    triple and scans its tokens sequentially, carrying a [BK, BV] recurrent
    state ``b_h``:

        h_t = exp(g_t) * h_{t-1} + k_t * (beta_t * (v_t - h_{t-1}^T k_t))^T
        o_t = scale * h_t^T q_t

    When ``IS_KDA`` the decay ``g`` is a per-key-channel vector rather than a
    per-head scalar. The state is stored to ``ht`` after *every* token so that
    intermediate states survive for multi-query (spec-decoding) tokens.

    Fix vs. the original: a stray unconditional ``b_g = tl.load(p_g)`` sat
    right after the q/k/v loads. ``p_g`` is never assigned when ``IS_KDA`` is
    set, so that load referenced an undefined name on the KDA path, and on the
    non-KDA path it was redundant — the decay branch below reloads the value.
    The stray load has been removed; everything else is unchanged.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # map this value head to its (possibly grouped) query/key head
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # variable-length batch: token range for sequence i_n from cu_seqlens
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        tot = T  # total tokens across the whole (flattened) batch
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        tot = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # recurrent state tile for this (K-block, V-block) pair, kept in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # resume from the state left by the last accepted draft token
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # NOTE(review): i_t is not scaled by stride_indices_tok here (nor
            # in the final-state store below); fine when the last dim of
            # ssm_state_indices is contiguous — verify for other layouts.
            p_h0 = (
                h0
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_init_state_token
            )
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        # per-token pointers; q/k are [*, H, K], v is [*, HV, V] flattened
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t

        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t

        if not IS_KDA:
            p_g = g + bos * HV + i_hv + HV * i_t
        else:
            p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k

        p_o = o + ((i_k * tot + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # decay the state: scalar per head, or per key channel for KDA
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # delta-rule prediction error: v_t - h^T k_t, shape [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # rank-1 state update, shape [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # output: h^T q, shape [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = (
                ht
                + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                    tl.int64
                )
                * stride_final_state_token
            )
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

176 

177 

def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta-rule forward kernel.

    Shapes are taken from the inputs: ``k`` is ``[B, T, H, K]``, ``v`` is
    ``[B, T, HV, V]``. With ``inplace_final_state`` the per-token states are
    written back into ``initial_state`` (via ``ssm_state_indices``); otherwise
    a fresh ``[T, HV, K, V]`` buffer is allocated for them.

    Returns ``(o, final_state)`` where ``o`` has the same shape as ``v``.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    HV = v.shape[2]
    N = len(cu_seqlens) - 1 if cu_seqlens is not None else B

    # Block sizes: the whole key dim must fit in one block (NK == 1); the
    # value dim is tiled in chunks of at most 8.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 8)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    if ssm_state_indices is None:
        stride_indices_seq = 1
        stride_indices_tok = 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq = ssm_state_indices.stride(0)
        stride_indices_tok = 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=initial_state.stride(0),
        stride_final_state_token=final_state.stride(0),
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=False,
        num_warps=1,
        num_stages=3,
    )
    return o.squeeze(0), final_state