Coverage for src/flag_gems/fused/flashmla_sparse.py: 31%

133 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1from typing import Optional, Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Autotuning candidates for the sparse MLA forward kernel.
# `num_stages` and `num_warps` are launch options, so they are passed as keyword
# arguments of `triton.Config`. The first positional dict is reserved for
# kernel meta-parameters (constexpr arguments); this kernel tunes none of
# those, hence the empty dict.
flash_mla_sparse_fwd_configs = [
    triton.Config({}, num_stages=4, num_warps=8),
    triton.Config({}, num_stages=2, num_warps=4),
]

11 

12 

@triton.autotune(  # Decorate the kernel
    configs=flash_mla_sparse_fwd_configs,
    key=["K", "is_causal"],
)
@triton.jit
def triton_flash_mla_sparse_fwd(
    q,
    kv,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    max_logits,
    lse,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kvg,
    stride_kvn,
    stride_kvd,
    stride_tg,
    stride_tm,
    stride_tt,  # indices dim
    stride_attn_sink_h,
    stride_topk_length_s,
    stride_oh,
    stride_om,
    stride_od,
    stride_mh,
    stride_mm,
    stride_lh,
    stride_lm,
    SQ: tl.constexpr,  # seqlen
    SKV: tl.constexpr,  # seqlen_kv
    K: tl.constexpr,  # topk
    D: tl.constexpr,  # QKV dim
    TD: tl.constexpr,  # tail dim
    DP: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,  # group_size
    BK: tl.constexpr,
    BH: tl.constexpr,
    is_causal: tl.constexpr,
    q_idx_i64: tl.constexpr,
    output_idx_i64: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
):
    """Sparse (top-k gathered) attention forward for one query row and one head block.

    Program axis 0 selects the query position, program axis 1 selects a block
    of ``BH`` heads. The ``K`` top-k slots of ``indices`` are processed in
    chunks of ``BK``: the referenced kv rows are gathered, scores are computed
    as ``(q . kv) * sm_scale`` over the first ``D`` features plus the ``TD``
    tail features, and an online softmax accumulates the weighted sum of the
    first ``D`` kv features.

    Pointers and strides:
        q / output: addressed with ``stride_*m`` (query position),
            ``stride_*h`` (head) and ``stride_*d`` (feature).
        kv: addressed with ``stride_kvg`` (kv group), ``stride_kvn`` (row
            selected by an index) and ``stride_kvd`` (feature).
        indices: addressed with ``stride_tm`` (query position), ``stride_tg``
            (kv group) and ``stride_tt`` (top-k slot).
        attn_sink: per-head value, read only when ``HAVE_ATTN_SINK``.
        topk_length: per-query count of usable top-k slots, read only when
            ``HAVE_TOPK_LENGTH``.
        max_logits / lse: per (query, head) float32 statistics.

    Compile-time parameters:
        DP / TDP: padded extents used for ``tl.arange`` over ``D`` / ``TD``;
            ``TDP == 0`` disables the tail path.
        q_idx_i64 / output_idx_i64: widen the query-position offset of
            ``q`` / ``output`` to int64.

    Stores:
        output: softmax-weighted kv values, additionally scaled by
            ``exp(lse) / (exp(lse) + exp(attn_sink))`` when a sink is given.
        max_logits: running maximum of the scaled scores (``-inf`` if no
            valid key was seen).
        lse: natural-log sum-exp of the scaled scores; ``+inf`` is written
            when no valid key was seen.
    """
    i_sq, i_gbh = tl.program_id(0), tl.program_id(1)
    # NOTE(review): this decode equals (group, head-block) only while the launch
    # grid has fewer than G programs on axis 1, which holds for the single
    # kv-group case the host wrapper in this file enforces. Revisit before
    # enabling more than one kv group.
    i_g, i_bh = i_gbh // G, i_gbh % G
    if not q_idx_i64:
        q_base = q + i_sq * stride_qm + i_gbh * (BH * stride_qh)
    else:
        q_base = q + i_sq * tl.cast(stride_qm, tl.int64) + i_gbh * (BH * stride_qh)
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_sq * stride_tm + i_g * stride_tg
    attn_sink_ptr = (
        attn_sink + i_gbh * (BH * stride_attn_sink_h) if HAVE_ATTN_SINK else 0
    )
    topk_length_ptr = (
        topk_length + i_sq * stride_topk_length_s if HAVE_TOPK_LENGTH else 0
    )
    if not output_idx_i64:
        o_base = output + i_sq * stride_om + i_gbh * (BH * stride_oh)
    else:
        o_base = output + i_sq * tl.cast(stride_om, tl.int64) + i_gbh * (BH * stride_oh)
    max_log_base = max_logits + i_sq * stride_mm + i_gbh * (BH * stride_mh)
    l_base = lse + i_sq * stride_lm + i_gbh * (BH * stride_lh)

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP) if TDP > 0 else None
    offs_od = tl.arange(0, DP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD if TDP > 0 else None
    mask_od = mask_d

    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0).to(tl.float32)

    tq_blk = None
    if TDP > 0:
        tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
        tq_msk = mask_h[:, None] & mask_td[None, :]
        tq_blk = tl.load(tq_ptr, tq_msk, other=0.0).to(tl.float32)

    # Online-softmax state: running max, running sum of exp, weighted accumulator.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc = tl.zeros([BH, DP], dtype=tl.float32)
    qk = tl.zeros([BH, BK], dtype=tl.float32)

    max_col = i_sq if is_causal else SKV - 1
    # Clamp to K so a topk_length larger than the index row can never unmask
    # slots that lie beyond the row.
    topk_len = (
        tl.minimum(tl.load(topk_length_ptr).to(tl.int32), K)
        if HAVE_TOPK_LENGTH
        else K
    )

    NK = tl.cdiv(K, BK)
    for ck in range(NK):
        # step1: load indices
        # The validity mask is taken on the slot index itself; the stride is
        # applied only when forming the address.
        t_idx = BK * ck + offs_t
        t_msk = t_idx < topk_len
        t_ptr = t_base + t_idx * stride_tt
        kv_ids = tl.load(t_ptr, t_msk, other=-1)
        mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)
        # filter invalid index that may cause overflow in mul
        kv_ids = tl.where(mask_ids, kv_ids, 0)

        # Skip only chunks that start past the usable slot count. A slot
        # position says nothing about the kv row it points to, so it must not
        # be compared with max_col; per-slot validity is handled by mask_ids.
        if ck * BK < topk_len:
            # step2: gather kv with indices
            kv_ptr = (
                kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tl.load(kv_ptr, kv_msk, other=0.0).to(tl.float32)  # [DP, BK]

            # step3: (q @ kv) * sm_scale
            qk = tl.dot(q_blk, kv_blk)
            if TDP > 0:
                tkv_ptr = (
                    tkv_base
                    + offs_td[:, None] * stride_kvd
                    + kv_ids[None, :] * stride_kvn
                )
                tkv_msk = mask_td[:, None] & mask_ids[None, :]
                tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0).to(
                    tl.float32
                )  # [TDP, BK]
                qk = tl.dot(tq_blk, tkv_blk, qk) * sm_scale
            else:
                qk = qk * sm_scale

            # step4: invalid slots must not contribute to the softmax
            qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]
            # step5: running log-sum-exp (natural log), loop part
            new_max = tl.maximum(max_log, tl.max(qk, axis=1))  # [BH]
            # avoid nan generated by ((-inf) - (-inf))
            tmp = qk - new_max[:, None]
            tmp = tl.where(
                (~mask_ids[None, :]) & (new_max[:, None] == float("-inf")),
                float("-inf"),
                tmp,
            )
            exp_qk = tl.math.exp(tmp)  # [BH, BK]
            sum_qk = tl.sum(exp_qk, axis=1)  # [BH]
            # avoid nan generated by ((-inf) - (-inf))
            tmp2 = max_log - new_max
            tmp2 = tl.where(
                (max_log == float("-inf")) & (new_max == float("-inf")),
                float("-inf"),
                tmp2,
            )
            # alpha rescales the previous state to the new running maximum
            alpha = tl.math.exp(tmp2)  # [BH]
            sum_exp = tl.fma(sum_exp, alpha, sum_qk)  # [BH]
            acc = acc * alpha[:, None]  # [BH, DP]
            # step6: exp(qk - max) @ gathered_kv.trans(), loop part
            acc = tl.dot(exp_qk, kv_blk.trans(), acc)  # [BH, DP]
            max_log = new_max

    # step7: store max_logits (addressed with its own head stride)
    max_log_ptr = max_log_base + offs_h * stride_mh
    tl.store(max_log_ptr, max_log, mask_h)  # [BH], float32

    # step8: final log-sum-exp; rows without any valid key report +inf
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(orig_lse == float("-inf"), float("inf"), orig_lse)
    l_ptr = l_base + offs_h * stride_lh
    l_msk = mask_h
    tl.store(l_ptr, lse_out, l_msk)  # [BH], float32

    # step9: normalise the accumulator
    if HAVE_ATTN_SINK:
        # step10: attn_sink adds exp(sink) to the softmax denominator
        # NOTE(review): exp(max_log) and exp(orig_lse) are taken unshifted, so
        # very large scaled logits can overflow float32 here.
        exp_max_qk = tl.math.exp(max_log)  # [BH]
        exp_orig_lse = tl.math.exp(orig_lse)
        sink = tl.load(
            attn_sink_ptr + offs_h * stride_attn_sink_h, mask=mask_h, other=0.0
        ).to(tl.float32)  # [BH]
        exp_sink = tl.math.exp(sink)
        sum_exp_new_lse = exp_orig_lse + exp_sink
        # avoid divide 0
        sum_exp_new_lse = tl.where(sum_exp_new_lse == 0.0, 1.0, sum_exp_new_lse)
        factor = exp_max_qk / sum_exp_new_lse
        out_vals = acc * factor[:, None]
    else:
        # avoid divide 0
        sum_exp = tl.where(sum_exp == 0, 1.0, sum_exp)
        out_vals = acc / sum_exp[:, None]

    # step11: store output, cast to the element type of the output buffer
    o_ptr = (
        o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
    )  # [BH, DP]
    o_msk = mask_h[:, None] & mask_od[None, :]
    tl.store(o_ptr, out_vals.to(output.dtype.element_ty), o_msk)

210 

211 

def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse-attention prefill kernel.

    Args:
        q: [s_q, h_q, d_qk], bfloat16, contiguous.
        kv: [s_kv, h_kv, d_qk], bfloat16, contiguous.
        indices: [s_q, h_kv, topk], int32, contiguous. Mark unused entries
            with -1 or with any value >= s_kv.
        sm_scale: scale applied to the q.k scores.
        d_v: width of the value vectors; only 512 is accepted.
        attn_sink: optional, [h_q], float32.
            When given, the output is additionally multiplied by
            exp(lse) / (exp(lse) + exp(attn_sink)). Infinite entries are
            handled normally: -inf leaves the output unchanged, +inf zeroes
            the corresponding output. lse and max_logits are not affected.
        topk_length: optional, [s_q], int32.
            When given, query token i attends only to the keys listed in
            indices[i, :, :topk_length[i]]; later entries are ignored even if
            they are valid. In the rare case where a valid index located
            between topk_length[i] and s_kv points to a key containing NaN,
            the output will contain NaN, so avoid that situation.

    Returns:
        (output, max_logits, lse); tests/ref.py holds the precise definitions.
        - output: [s_q, h_q, d_v], same dtype as q
        - max_logits: [s_q, h_q], float32
        - lse: [s_q, h_q], float32, log-sum-exp of the attention scores
    """
    causal = False  # the causal-specific optimisation is intentionally off
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    num_q, num_heads, qk_dim = q.shape
    num_kv, kv_heads, _ = kv.shape

    # Shape validation; the order and messages of these checks are part of
    # the observable behaviour.
    assert d_v == 512, "Unsupported d_v"
    assert kv.shape[-1] == qk_dim
    _, _, topk = indices.shape
    assert indices.shape == (num_q, kv_heads, topk)
    has_sink = attn_sink is not None
    has_topk_length = topk_length is not None
    if has_sink:
        assert attn_sink.shape == (num_heads,), "attn_sink error shape"
    if has_topk_length:
        assert topk_length.shape == (num_q,), "topk_length error shape"

    # check from FlashMLA
    assert kv_heads == 1, "h_kv is expected to be 1"
    assert num_heads == 64 or num_heads == 128, "Unsupported h_q"
    assert qk_dim == 576 or qk_dim == 512, "Unsupported d_qk"

    # Feature split: the first d_v features are used as values, the rest is
    # the score-only tail. Padded sizes are powers of two.
    head_dim = d_v
    tail_dim = qk_dim - head_dim
    head_dim_pad = triton.next_power_of_2(head_dim)
    tail_dim_pad = triton.next_power_of_2(tail_dim)

    # Tiling: heads per program (between 16 and 32) and top-k slots per step.
    group_size = num_heads // kv_heads
    block_h = max(16, min(32, triton.next_power_of_2(group_size)))
    head_blocks = triton.cdiv(group_size, block_h)
    block_k = 16  # used to be out of memory for 32

    output = torch.zeros((num_q, num_heads, head_dim), device=q.device, dtype=q.dtype)
    stat_opts = {"device": q.device, "dtype": torch.float32}
    max_logits = torch.full((num_q, num_heads), float("-inf"), **stat_opts)
    lse = torch.full((num_q, num_heads), float("-inf"), **stat_opts)

    # Switch the kernel to 64-bit row offsets once a tensor has more elements
    # than a signed 32-bit integer can index.
    int32_limit = 2147483647
    q_needs_i64 = q.numel() > int32_limit
    out_needs_i64 = output.numel() > int32_limit

    stride_qm, stride_qh, stride_qd = q.stride()
    stride_kvn, stride_kvg, stride_kvd = kv.stride()
    stride_tm, stride_tg, stride_tt = indices.stride()
    stride_om, stride_oh, stride_od = output.stride()
    stride_mm, stride_mh = max_logits.stride()
    stride_lm, stride_lh = lse.stride()
    stride_sink = attn_sink.stride(0) if has_sink else 0
    stride_topk_length = topk_length.stride(0) if has_topk_length else 0

    # One program per (query token, head block).
    grid = (num_q, kv_heads * head_blocks, 1)
    triton_flash_mla_sparse_fwd[grid](
        q,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale,
        output,
        max_logits,
        lse,
        stride_qh,
        stride_qm,
        stride_qd,
        stride_kvg,
        stride_kvn,
        stride_kvd,
        stride_tg,
        stride_tm,
        stride_tt,
        stride_sink,
        stride_topk_length,
        stride_oh,
        stride_om,
        stride_od,
        stride_mh,
        stride_mm,
        stride_lh,
        stride_lm,
        num_q,
        num_kv,
        topk,
        head_dim,
        tail_dim,
        head_dim_pad,
        tail_dim_pad,
        group_size,
        block_k,
        block_h,
        causal,
        q_needs_i64,
        out_needs_i64,
        has_sink,
        has_topk_length,
    )
    return output, max_logits, lse