Coverage for src/flag_gems/runtime/backend/_ascend/fla/chunk_delta_h.py: 0%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp 

16 

# Marker tuple, presumably consumed by the backend's kernel-selection /
# tuning machinery ("seq7168" suggests a sequence-length-7168 condition) —
# TODO(review): confirm against the runtime dispatcher that reads it.
_CONDITIONS = ("seq7168",)

18 

19 

# Forward recurrence of the chunked (gated) delta rule.
#
# One Triton program per (sequence, head) pair — grid is (1, N * H) — scans
# its sequence chunk by chunk, carrying a [K, V] recurrent state in registers.
# The state is materialized as two fixed [128, 64] tiles along the V axis
# (v_start1 = 0, v_start2 = 64), so V <= 128 is handled here; K rows beyond
# 128 are never touched by this kernel.
# NOTE(review): the host wrapper asserts K <= 256, but `offs_k` only spans
# 128 rows — confirm whether K in (128, 256] is actually supported.
@triton.heuristics(
    {
        # Feature flags are derived from which optional tensors the host
        # passed, so one kernel source covers every call variant.
        "USE_G": lambda args: args["g"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,  # keys, indexed as (B, T, Hg, K)
    v,  # values, indexed as (B, T, H, V)
    w,  # per-token update vectors, indexed as (B, T, H, K)
    v_new,  # optional output: delta-corrected values, same layout as v
    g,  # per-token log-gates; host passes it transposed to (B, H, T)
    h,  # output: per-chunk states, indexed as (B, NT, H, K, V)
    h0,  # optional initial state, (N, H, K, V)
    ht,  # optional final-state output, (N, H, K, V)
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    chunk_offsets,  # per-sequence chunk offsets into h (varlen mode)
    T,  # (max) sequence length; runtime value, not specialized
    H: tl.constexpr,  # number of value heads
    Hg: tl.constexpr,  # number of key heads (Hg divides H, GQA-style)
    K: tl.constexpr,  # key head dimension
    V: tl.constexpr,  # value head dimension
    BT: tl.constexpr,  # chunk (block-time) size
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Axis 1 of the launch grid enumerates (sequence, head) pairs.
    i_nh = tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    # Capture the original (pre-varlen) T: `g` is indexed with a per-head
    # stride of T_max even after T is rebound to the varlen length below.
    T_max = 1 * T
    if IS_VARLEN:
        # Sequence boundaries come from the cu_seqlens prefix-sum table.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        # Start chunk index of this sequence inside the packed `h` buffer.
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Row strides (in elements) for one timestep of v / k / w.
    stride_v = H * V
    stride_k = Hg * K
    stride_w = H * K

    # Recurrent state: two [128, 64] fp32 tiles covering V columns
    # [0, 64) and [64, 128).
    b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32)
    b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32)

    v_start1 = 0
    v_start2 = 64

    offs_k = tl.arange(0, 128)[:, None]
    offs_v1 = v_start1 + tl.arange(0, 64)[None, :]
    offs_v2 = v_start2 + tl.arange(0, 64)[None, :]
    # Masks keep out-of-range K rows / V columns of the fixed tiles inert.
    mask_kv1 = (offs_k < K) & (offs_v1 < V)
    mask_kv2 = (offs_k < K) & (offs_v2 < V)

    # load initial state
    if USE_INITIAL_STATE:
        h0_ptr = h0 + i_nh * K * V
        ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
        b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)

        ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
        b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        # `h` stores the state as seen at the START of chunk i_t.
        h_base = h + (boh + i_t) * H * K * V + i_h * K * V

        p_h1_bv1 = tl.make_block_ptr(
            h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0)
        )
        tl.store(
            p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1)
        )

        p_h1_bv2 = tl.make_block_ptr(
            h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0)
        )
        tl.store(
            p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1)
        )

        # Load the [BT, K] tile of w for this chunk (masked at the tail).
        offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
        offs_k_wv = tl.arange(0, 128)[None, :]
        mask_w = (offs_t_wv < T) & (offs_k_wv < K)

        w_base = w + bos * H * K + i_h * K
        ptr_w = w_base + offs_t_wv * stride_w + offs_k_wv * 1
        b_w = tl.load(ptr_w, mask=mask_w, other=0.0)

        # Keys are loaded transposed ([K, BT]); heads are shared GQA-style
        # via i_h // (H // Hg).
        k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
        p_k = tl.make_block_ptr(
            k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1)
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))

        # NOTE(review): computed even when SAVE_NEW_VALUE is False, i.e.
        # when the host passes v_new=None — confirm Triton tolerates the
        # dead pointer arithmetic in that configuration.
        v_new_base = v_new + bos * H * V + i_h * V

        # Gate value at the last valid timestep of this chunk.
        # NOTE(review): `g` is loaded unconditionally (not guarded by
        # USE_G), so this kernel effectively requires g != None. Also, the
        # offset `bos + i_h * T_max` matches a (H, T) layout only when
        # B == 1 (or varlen); for a fixed-length batch with B > 1 a
        # (B, H, T) layout would need an extra * H factor — confirm
        # intended usage.
        last_idx = min((i_t + 1) * BT, T) - 1
        b_g_last = tl.load(g + bos + i_h * T_max + last_idx)

        offs_t = i_t * BT + tl.arange(0, BT)
        mask_t = offs_t < T
        g_ptr = g + bos + i_h * T_max
        b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0)

        # Per-token decay toward the chunk end (clamped exp), and the
        # whole-chunk decay applied to the carried state.
        # NOTE(review): b_g_last uses plain tl.exp, not safe_exp —
        # presumably the cumulative gate is bounded here; confirm.
        b_g = safe_exp(b_g_last - b_g)
        b_g_last = tl.exp(b_g_last)

        offs_t_v = (i_t * BT + tl.arange(0, BT))[:, None]
        mask_v1 = (offs_t_v < T) & (offs_v1 < V)

        # --- first V half: delta-correct v, then rank-BT state update ---
        v_base = v + bos * H * V + i_h * V
        ptr_v1 = v_base + offs_t_v * stride_v + offs_v1 * 1
        b_v1 = tl.load(ptr_v1, mask=mask_v1, other=0.0)
        b_v_new1 = b_v1.to(tl.float32)
        # v_new = v - w @ h  (the "delta" correction against current state)
        b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))

        if SAVE_NEW_VALUE:
            p_v_new1 = tl.make_block_ptr(
                v_new_base,
                (T, V),
                (stride_v, 1),
                (i_t * BT, v_start1),
                (BT, 64),
                (1, 0),
            )
            tl.store(
                p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1)
            )

        if USE_G:
            # Decay the new values to the chunk end and the carried state
            # across the whole chunk before accumulating.
            b_v_new1 = b_v_new1 * b_g[:, None]
            b_h1_bv1 = b_h1_bv1 * b_g_last

        b_v_new1 = b_v_new1.to(k.dtype.element_ty)
        # State update: h += k^T @ v_new for this chunk.
        b_h1_bv1 += tl.dot(b_k, b_v_new1)

        # --- second V half: identical update for columns [64, 128) ---
        mask_v2 = (offs_t_v < T) & (offs_v2 < V)
        ptr_v2 = v_base + offs_t_v * stride_v + offs_v2 * 1
        b_v2 = tl.load(ptr_v2, mask=mask_v2, other=0.0)
        b_v_new2 = b_v2.to(tl.float32)
        b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))

        if SAVE_NEW_VALUE:
            p_v_new2 = tl.make_block_ptr(
                v_new_base,
                (T, V),
                (stride_v, 1),
                (i_t * BT, v_start2),
                (BT, 64),
                (1, 0),
            )
            tl.store(
                p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1)
            )

        if USE_G:
            b_v_new2 = b_v_new2 * b_g[:, None]
            b_h1_bv2 = b_h1_bv2 * b_g_last

        b_v_new2 = b_v_new2.to(k.dtype.element_ty)
        b_h1_bv2 += tl.dot(b_k, b_v_new2)

    # epilogue
    if STORE_FINAL_STATE:
        # Persist the post-recurrence state (float32 on the host side).
        ht_ptr = ht + i_nh * K * V

        p_ht1_bv1 = tl.make_block_ptr(
            ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0)
        )
        tl.store(
            p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1)
        )

        p_ht1_bv2 = tl.make_block_ptr(
            ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0)
        )
        tl.store(
            p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1)
        )

211 

212 

def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Launch the chunked gated-delta-rule forward hidden-state kernel.

    Args:
        k: keys of shape ``(B, T, Hg, K)``.
        w: per-token update vectors of shape ``(B, T, H, K)`` — layout
            inferred from the kernel's stride computation; confirm.
        u: values of shape ``(B, T, H, V)``.
        g: per-token log-gates of shape ``(B, T, H)``. Despite the Optional
            annotation, the kernel reads ``g`` unconditionally, so it is
            required here (a clear ``ValueError`` is raised for ``None``).
        initial_state: optional ``(N, H, K, V)`` initial hidden state.
        output_final_state: when True, also return the ``(N, H, K, V)``
            final state in float32.
        chunk_size: temporal chunk length ``BT``.
        save_new_value: when True, materialize the delta-corrected values.
        cu_seqlens: optional cumulative sequence lengths for varlen input.

    Returns:
        ``(h, v_new, final_state)``: per-chunk hidden states of shape
        ``(B, NT, H, K, V)``; delta-corrected values (or ``None`` when
        ``save_new_value`` is False); final state (or ``None`` when
        ``output_final_state`` is False).

    Raises:
        ValueError: if ``g`` is ``None``.
    """
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size

    # Fail fast with a clear message: the Triton kernel dereferences `g`
    # unconditionally, so passing None would otherwise surface as an opaque
    # AttributeError on the transpose below (or a kernel compile error).
    if g is None:
        raise ValueError("chunk_gated_delta_rule_fwd_h requires `g`; got None.")

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    # NOTE(review): the kernel only materializes 128 K rows of state
    # (tl.arange(0, 128)); confirm K in (128, 256] is actually supported
    # before relying on this assert.
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = (
        k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    )

    v_new = torch.empty_like(u) if save_new_value else None
    # The kernel indexes g with a per-head stride of T, i.e. a contiguous
    # (B, H, T) layout.
    g = g.transpose(1, 2).contiguous()

    def grid(meta):
        # One program per (sequence, head) pair on axis 1.
        return (1, N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        num_warps=4,
        num_stages=2,
    )
    return h, v_new, final_state