Coverage for src/flag_gems/runtime/backend/_ascend/fla/sigmoid_gating.py: 0%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang 

4# 

5# This file contains code copied from the flash-linear-attention project. 

6# The original source code was licensed under the MIT license and included 

7# the following copyright notice: 

8# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

9# ruff: noqa: E501 

10# mypy: ignore-errors 

11import torch 

12import triton 

13import triton.language as tl 

14 

15 

@triton.jit
def div_normal(x, y):
    """Return the plain elementwise quotient ``x / y`` (no special handling)."""
    quotient = x / y
    return quotient

19 

20 

# Math-op aliases used by the kernels in this file. Indirection keeps call
# sites stable if a backend-specific variant needs to be swapped in; today
# they resolve to plain division and Triton's native exp/log/log2.
div = div_normal
exp = tl.exp
log = tl.log
log2 = tl.log2

25 

26 

@triton.heuristics(
    {
        # Enabled when a source buffer for initial states is provided.
        "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
        # Enabled when cumulative sequence lengths are provided (varlen batch).
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,  # per-value-head log decay magnitude, indexed by i_hv
    a,  # per-(token, value-head) gating input, layout [bos..eos, HV]
    dt_bias,  # per-value-head bias added to `a` before softplus
    softplus_beta,  # softplus sharpness (scalar)
    softplus_threshold,  # beta*x above this uses the linear approximation
    q,  # queries, layout [tokens, H, K]
    k,  # keys, layout [tokens, H, K]
    v,  # values, layout [tokens, HV, V]
    b,  # pre-sigmoid beta logits, layout [tokens, HV]
    o,  # output, layout [NK, tokens, HV, V]
    h0_source,  # pool of initial/final hidden states, [num_slots, HV, K, V]
    h0_indices,  # per-sequence slot index into h0_source (< 0 means "none")
    cu_seqlens,  # varlen prefix sums, or None for fixed-length batches
    scale,  # multiplier applied to q before the output reduction
    T,  # sequence length (total tokens when IS_VARLEN)
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of value heads (HV % H == 0 assumed by i_h calc)
    K: tl.constexpr,  # q/k head dimension
    V: tl.constexpr,  # value head dimension
    BK: tl.constexpr,  # K-tile size (>= K since NK == 1)
    BV: tl.constexpr,  # V-tile size
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    Each program instance owns one (K-tile, V-tile, sequence x value-head)
    triple and scans its sequence one timestep at a time, carrying the
    [BK, BV] hidden state ``b_h`` in float32. Per step:
      g    = -exp(A_log) * softplus(a + dt_bias)
      beta = sigmoid(b)
      h    = h * exp(g);  v' = beta * (v - sum_k h*k);  h += k (x) v'
      o_t  = sum_k h[k, :] * (scale * q_t)[k]
    The final state is written back into ``h0_source`` at the sequence's slot.
    """
    # Grid: axis0 = K-tile, axis1 = V-tile, axis2 = (sequence, value-head).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map value head -> its q/k head (grouped heads: HV // H value heads share one).
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Per-sequence [bos, eos) token range from the prefix-sum table.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T  # total token count across the batch (shadows builtin `all`)
        T = eos - bos  # this sequence's own length
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # Per-tile element offsets along K and V.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Base pointers for timestep 0 of this sequence/head; advanced by a
    # per-step stride inside the loop below.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v  # o has a leading NK axis

    # Gating computation pointers
    p_A_log = A_log + i_hv
    p_a = a + bos * HV + i_hv
    p_dt_bias = dt_bias + i_hv

    # Bounds masks for partial tiles at the K/V edges.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent hidden state for this (K-tile, V-tile), kept in float32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # if idx >= 0:
        # Branchless load: negative idx means "no initial state". Clamp the
        # index to 0 so the load address is valid, then zero out the result.
        tmp0 = tl.where(idx < 0, 0, idx)
        p_h0 = (
            h0_source
            + tmp0 * HV * K * V
            + i_hv * K * V
            + o_k[:, None] * V
            + o_v[None, :]
        )
        temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
        temp2 = tl.zeros_like(temp1)
        value0 = tl.where(idx < 0, temp2, temp1)
        b_h += value0  # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over the timesteps of this sequence.
    for i in range(0, T):
        # Load inputs
        b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b + i * HV).to(tl.float32)

        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        b_a = tl.load(p_a + i * HV).to(tl.float32)
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability: fall back to the linear
        # identity softplus(x) ~= x when beta*x exceeds the threshold.
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x

        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))

        # Apply L2 normalization if enabled (epsilon guards division by zero)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)

        b_q = b_q * scale

        # Apply gating to hidden state: h *= exp(g)
        b_h *= tl.exp(b_g)

        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)

        # Apply beta gating: v *= beta
        b_v *= b_beta

        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]

        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # # Update pointers for next timestep (kept folded into the loads above)
        # p_q += H * K
        # p_k += H * K
        # p_o += HV * V
        # p_v += HV * V
        # p_b += HV
        # p_a += HV

    # Store final state back to h0_source with bounds checking
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)

184 

185 

def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: float = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.Tensor = None,
):
    """
    Fused triton implementation of sigmoid gating delta rule update.

    This function uses a single fused kernel that combines both sigmoid gating
    computation and the recurrent delta rule update for better performance.

    Args:
        A_log: Per-value-head log decay magnitude.
        a: Gating input, shape matching [B, T, HV].
        dt_bias: Per-value-head bias added to ``a`` before softplus.
        softplus_beta: Softplus sharpness parameter.
        softplus_threshold: Linear-approximation cutoff for softplus.
        q, k: Query/key tensors of shape [B, T, H, K].
        v: Value tensor of shape [B, T, HV, V].
        b: Pre-sigmoid beta logits, shape [B, T, HV].
        initial_state_source: Pool of per-sequence hidden states; updated
            in place with the final states.
        initial_state_indices: Per-sequence slot index into the pool
            (negative means "no initial state").
        scale: Query scaling factor; defaults to K ** -0.5 when None.
        use_qk_l2norm_in_kernel: L2-normalize q/k inside the kernel.
        cu_seqlens: Optional cumulative sequence lengths for varlen batches.

    Returns:
        Output tensor with the same shape as ``v``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    # Number of sequences: batch size, or len(prefix sums) - 1 for varlen.
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"

    # Leading NK axis matches the kernel's output layout; squeezed on return.
    o = q.new_empty(NK, *v.shape)
    grid = (NK, NV, N * HV)

    # The kernel computes flat offsets, so every tensor it indexes must be
    # contiguous in memory.
    if not initial_state_indices.is_contiguous():
        initial_state_indices = initial_state_indices.contiguous()
    if not initial_state_source.is_contiguous():
        initial_state_source = initial_state_source.contiguous()
    # Fix: cu_seqlens defaults to None (fixed-length batches); guard before
    # calling .is_contiguous() to avoid an AttributeError on that path.
    if cu_seqlens is not None and not cu_seqlens.is_contiguous():
        cu_seqlens = cu_seqlens.contiguous()

    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=o,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # Drop the NK == 1 leading axis. (The original had a second, unreachable
    # squeeze/return pair after this return; it has been removed.)
    o = o.squeeze(0)
    return o