Coverage for src/flag_gems/fused/flashmla_sparse.py: 31%
133 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1from typing import Optional, Tuple
3import torch
4import triton
5import triton.language as tl
# Candidate launch configurations tried by the autotuner for the sparse-MLA
# forward kernel, given as (num_stages, num_warps) pairs.
flash_mla_sparse_fwd_configs = [
    triton.Config({"num_stages": stages, "num_warps": warps})
    for stages, warps in ((4, 8), (2, 4))
]
13@triton.autotune( # Decorate the kernel
14 configs=flash_mla_sparse_fwd_configs,
15 key=["K", "is_causal"],
16)
17@triton.jit
18def triton_flash_mla_sparse_fwd(
19 q,
20 kv,
21 indices,
22 attn_sink,
23 topk_length,
24 sm_scale: tl.constexpr,
25 output,
26 max_logits,
27 lse,
28 stride_qh,
29 stride_qm,
30 stride_qd,
31 stride_kvg,
32 stride_kvn,
33 stride_kvd,
34 stride_tg,
35 stride_tm,
36 stride_tt, # indices dim
37 stride_attn_sink_h,
38 stride_topk_length_s,
39 stride_oh,
40 stride_om,
41 stride_od,
42 stride_mh,
43 stride_mm,
44 stride_lh,
45 stride_lm,
46 SQ: tl.constexpr, # seqlen
47 SKV: tl.constexpr, # seqlen_kv
48 K: tl.constexpr, # topk
49 D: tl.constexpr, # QKV dim
50 TD: tl.constexpr, # tail dim
51 DP: tl.constexpr,
52 TDP: tl.constexpr,
53 G: tl.constexpr, # group_size
54 BK: tl.constexpr,
55 BH: tl.constexpr,
56 is_causal: tl.constexpr,
57 q_idx_i64: tl.constexpr,
58 output_idx_i64: tl.constexpr,
59 HAVE_ATTN_SINK: tl.constexpr,
60 HAVE_TOPK_LENGTH: tl.constexpr,
61):
62 i_sq, i_gbh = tl.program_id(0), tl.program_id(1)
63 i_g, i_bh = i_gbh // G, i_gbh % G
64 if not q_idx_i64:
65 q_base = q + i_sq * stride_qm + i_gbh * (BH * stride_qh)
66 else:
67 q_base = q + i_sq * tl.cast(stride_qm, tl.int64) + i_gbh * (BH * stride_qh)
68 tq_base = q_base + D * stride_qd
69 kv_base = kv + i_g * stride_kvg
70 tkv_base = kv_base + D * stride_kvd
71 t_base = indices + i_sq * stride_tm + i_g * stride_tg
72 attn_sink_ptr = (
73 attn_sink + i_gbh * (BH * stride_attn_sink_h) if HAVE_ATTN_SINK else 0
74 )
75 topk_length_ptr = (
76 topk_length + i_sq * stride_topk_length_s if HAVE_TOPK_LENGTH else 0
77 )
78 if not output_idx_i64:
79 o_base = output + i_sq * stride_om + i_gbh * (BH * stride_oh)
80 else:
81 o_base = output + i_sq * tl.cast(stride_om, tl.int64) + i_gbh * (BH * stride_oh)
82 max_log_base = max_logits + i_sq * stride_mm + i_gbh * (BH * stride_mh)
83 l_base = lse + i_sq * stride_lm + i_gbh * (BH * stride_lh)
85 offs_h = tl.arange(0, BH)
86 offs_d = tl.arange(0, DP)
87 offs_td = tl.arange(0, TDP) if TDP > 0 else None
88 offs_od = tl.arange(0, DP)
89 offs_t = tl.arange(0, BK)
90 mask_h = i_bh * BH + offs_h < G
91 mask_d = offs_d < D
92 mask_td = offs_td < TD if TDP > 0 else None
93 mask_od = mask_d
95 q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
96 q_msk = mask_h[:, None] & mask_d[None, :]
97 q_blk = tl.load(q_ptr, q_msk, other=0.0).to(tl.float32)
99 tq_blk = None
100 if TDP > 0:
101 tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
102 tq_msk = mask_h[:, None] & mask_td[None, :]
103 tq_blk = tl.load(tq_ptr, tq_msk, other=0.0).to(tl.float32)
105 max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
106 sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
107 acc = tl.zeros([BH, DP], dtype=tl.float32)
108 qk = tl.zeros([BH, BK], dtype=tl.float32)
110 max_col = i_sq if is_causal else SKV - 1
111 topk_len = tl.load(topk_length_ptr).to(tl.int32) if HAVE_TOPK_LENGTH else K
113 NK = tl.cdiv(K, BK)
114 for ck in range(NK):
115 # step1: load indices
116 t_ptr = (BK * ck + offs_t) * stride_tt
117 t_msk = t_ptr < topk_len
118 t_ptr += t_base
119 kv_ids = tl.load(t_ptr, t_msk, other=-1)
120 mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)
121 # filter invalid index that may cause overflow in mul
122 kv_ids = tl.where(mask_ids, kv_ids, 0)
124 # if mask_ids.max(0) > 0:
125 if ck * BK <= max_col:
126 # step2: gather kv with indices
127 kv_ptr = (
128 kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
129 )
130 kv_msk = mask_d[:, None] & mask_ids[None, :]
131 kv_blk = tl.load(kv_ptr, kv_msk, other=0.0).to(tl.float32) # [DP, BK]
133 # step3: (q @ kv) * sm_scale
134 qk = tl.dot(q_blk, kv_blk)
135 if TDP > 0:
136 tkv_ptr = (
137 tkv_base
138 + offs_td[:, None] * stride_kvd
139 + kv_ids[None, :] * stride_kvn
140 )
141 tkv_msk = mask_td[:, None] & mask_ids[None, :]
142 tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0).to(
143 tl.float32
144 ) # [TDP, BK]
145 qk = tl.dot(tq_blk, tkv_blk, qk) * sm_scale
146 else:
147 qk = qk * sm_scale
149 # step4: preprocess for logsumexp
150 qk = tl.where(mask_ids[None, :], qk, float("-inf")) # [BH, BK]
151 # step5: lse=log2sumexp2(qk), loop part
152 new_max = tl.maximum(max_log, tl.max(qk, axis=1)) # [BH]
153 # avoid nan generated by ((-inf) - (-inf))
154 tmp = qk - new_max[:, None]
155 tmp = tl.where(
156 (~mask_ids[None, :]) & (new_max[:, None] == float("-inf")),
157 float("-inf"),
158 tmp,
159 )
160 exp_qk = tl.math.exp(tmp) # [BH, BK]
161 sum_qk = tl.sum(exp_qk, axis=1) # [BH]
162 # avoid nan generated by ((-inf) - (-inf))
163 tmp2 = max_log - new_max
164 tmp2 = tl.where(
165 (max_log == float("-inf")) & (new_max == float("-inf")),
166 float("-inf"),
167 tmp2,
168 )
169 alpha = tl.math.exp(tmp2) # [BH]
170 sum_exp = tl.fma(sum_exp, alpha, sum_qk) # [BH]
171 acc = acc * alpha[:, None] # [BH, DP]
172 # step6: exp2(qk-lse) @ gathered_kv.trans(), loop part
173 acc = tl.dot(exp_qk, kv_blk.trans(), acc) # [BH, DP]
174 max_log = new_max
176 # step7: store max_logits
177 max_log_ptr = max_log_base + offs_h * stride_lh
178 tl.store(max_log_ptr, max_log, mask_h) # [BH], float32
180 # step8: lse=log2sumexp2(qk) final part, store lse
181 orig_lse = max_log + tl.math.log(sum_exp)
182 lse_out = tl.where(orig_lse == float("-inf"), float("inf"), orig_lse)
183 l_ptr = l_base + offs_h * stride_lh
184 l_msk = mask_h
185 tl.store(l_ptr, lse_out, l_msk) # [BH], float32
187 # step9: exp2(qk-lse) @ gathered_kv.trans(), final part
188 if HAVE_ATTN_SINK:
189 # step10: attn_sink
190 exp_max_qk = tl.math.exp(max_log) # [BH]
191 exp_orig_lse = tl.math.exp(orig_lse)
192 sink = tl.load(attn_sink_ptr + offs_h).to(tl.float32) # [BH]
193 exp_sink = tl.math.exp(sink)
194 sum_exp_new_lse = exp_orig_lse + exp_sink
195 # avoid divide 0
196 sum_exp_new_lse = tl.where(sum_exp_new_lse == 0.0, 1.0, sum_exp_new_lse)
197 factor = exp_max_qk / sum_exp_new_lse
198 out_vals = acc * factor[:, None]
199 else:
200 # avoid divide 0
201 sum_exp = tl.where(sum_exp == 0, 1.0, sum_exp)
202 out_vals = acc / sum_exp[:, None]
204 # step11: store output
205 o_ptr = (
206 o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
207 ) # [BH, DP]
208 o_msk = mask_h[:, None] & mask_od[None, :]
209 tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
        sm_scale: float
        d_v: The dimension of value vectors. Can only be 512
        attn_sink: optional, [h_q], float32.
            If attn_sink is provided, when computing output, output will be additionally multiplied by
            exp(lse) / (exp(lse) + exp(attn_sink)). +-inf in attn_sink will be handled normally (i.e., -inf has no
            effect, +inf will make corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32.
            If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]],
            ignoring later k/v tokens (even if provided in indices). In extremely rare cases (topk_length provided,
            there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token
            containing NaN), operator output will contain NaN, so please avoid this situation.

    Returns:
        (output, max_logits, lse)
        Please refer to tests/ref.py for the precise definitions of these parameters.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores

    Raises:
        AssertionError: if a tensor is non-contiguous, lives on a different
            device than ``q``, or has an unsupported shape / ``d_v``.
    """
    is_causal = False  # turn off opt for causal sparse attention
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    # The kernel dereferences raw pointers, so every tensor has to live on
    # the device the outputs are allocated on.
    for arg_name, arg in (
        ("kv", kv),
        ("indices", indices),
        ("attn_sink", attn_sink),
        ("topk_length", topk_length),
    ):
        assert (
            arg is None or arg.device == q.device
        ), f"{arg_name} must be on the same device as q"

    SQ, H, DT = q.shape
    SKV, VG, _ = kv.shape

    assert d_v == 512, "Unsupported d_v"
    D = d_v

    assert kv.shape[-1] == DT
    TD = DT - D
    DP = triton.next_power_of_2(D)
    # TDP == 0 tells the kernel there is no tail part; state it explicitly
    # instead of relying on what next_power_of_2 returns for 0.
    TDP = triton.next_power_of_2(TD) if TD > 0 else 0
    _, _, K = indices.shape
    assert indices.shape == (SQ, VG, K)
    if attn_sink is not None:
        assert attn_sink.shape == (H,), "attn_sink error shape"
    if topk_length is not None:
        assert topk_length.shape == (SQ,), "topk_length error shape"

    # check from FlashMLA
    assert VG == 1, "h_kv is expected to be 1"
    assert H == 64 or H == 128, "Unsupported h_q"
    assert DT == 576 or DT == 512, "Unsupported d_qk"

    G = H // VG  # query heads per kv head
    BH = max(16, min(32, triton.next_power_of_2(G)))
    NH = triton.cdiv(G, BH)  # head blocks per kv head
    BK = 16  # used to be out of memory for 32
    output = torch.zeros((SQ, H, D), device=q.device, dtype=q.dtype)
    max_logits = torch.full(
        (SQ, H), float("-inf"), device=q.device, dtype=torch.float32
    )
    lse = torch.full((SQ, H), float("-inf"), device=q.device, dtype=torch.float32)
    if SQ == 0:
        # Nothing to compute; do not launch (or auto-tune) on an empty grid.
        return output, max_logits, lse

    # switch the kernel to 64-bit row offsets when 32-bit ones could overflow
    INT32_MAX = 2147483647
    q_idx_i64 = q.numel() > INT32_MAX
    output_idx_i64 = output.numel() > INT32_MAX
    grid = (SQ, VG * NH, 1)
    triton_flash_mla_sparse_fwd[grid](
        q,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale,
        output,
        max_logits,
        lse,
        q.stride(1),
        q.stride(0),
        q.stride(2),
        kv.stride(1),
        kv.stride(0),
        kv.stride(2),
        indices.stride(1),
        indices.stride(0),
        indices.stride(2),
        attn_sink.stride(0) if attn_sink is not None else 0,
        topk_length.stride(0) if topk_length is not None else 0,
        output.stride(1),
        output.stride(0),
        output.stride(2),
        max_logits.stride(1),
        max_logits.stride(0),
        lse.stride(1),
        lse.stride(0),
        SQ,
        SKV,
        K,
        D,
        TD,
        DP,
        TDP,
        G,
        BK,
        BH,
        is_causal,
        q_idx_i64,
        output_idx_i64,
        attn_sink is not None,
        topk_length is not None,
    )
    return output, max_logits, lse