Coverage for src/flag_gems/fused/DSA/sparse_mla.py: 10%
99 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
logger = logging.getLogger(__name__)

# Autotuning candidates for ``triton_sparse_mla_fwd``.
# ``num_stages`` is a launch option of ``triton.Config`` and must be passed as a
# keyword argument. Placing it inside the meta-parameter dict (the first
# argument) does not set the pipeline depth: Triton merges its own
# ``num_stages`` (default) over that dict, or rejects the duplicate keyword in
# older releases.
spar_mla_fwd_configs = [
    triton.Config({}, num_stages=4, num_warps=8),
    triton.Config({}, num_stages=2, num_warps=4),
]
# Sparse MLA forward kernel.
#
# Launched on the grid (B, SQ, VG * cdiv(G, BH)). Each program handles one
# batch element ``i_b``, one query position ``i_sq`` and a block of up to BH
# query heads that all belong to kv group ``i_g``.
#
# For each query it gathers the K kv rows listed in ``indices`` (BK at a time)
# and scores them in two parts: the first D channels of q/kv, plus the
# TD-channel tail stored after them. It then runs a streaming softmax.
#
# Running statistics are kept in base 2: scores are pre-multiplied by
# sm_scale * log2(e), so exp2/log2 can be used. The first D channels of kv
# also serve as the values.
#
# Outputs: the normalized attention output, and the log-sum-exp in base-2
# units (natural-log LSE / ln 2).
@triton.autotune(
    configs=spar_mla_fwd_configs,
    key=["K", "is_causal"],
)
@triton.jit
def triton_sparse_mla_fwd(
    q,
    kv,
    indices,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kvb,
    stride_kvg,
    stride_kvn,
    stride_kvd,
    stride_tb,
    stride_tg,
    stride_tm,
    stride_tt,  # indices dim
    stride_ob,
    stride_oh,
    stride_om,
    stride_od,
    stride_lb,
    stride_lh,
    stride_lm,
    SQ: tl.constexpr,  # seqlen
    K: tl.constexpr,  # topk
    D: tl.constexpr,  # QKV dim
    TD: tl.constexpr,  # tail dim
    DP: tl.constexpr,  # D rounded up to a power of two
    TDP: tl.constexpr,  # TD rounded up to a power of two
    G: tl.constexpr,  # group_size (query heads per kv group)
    BK: tl.constexpr,  # top-k entries processed per iteration
    BH: tl.constexpr,  # query heads processed per program
    is_causal: tl.constexpr,
):
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NH = tl.cdiv(G, BH)
    i_g, i_bh = i_gbh // NH, i_gbh % NH

    # First query head handled by this program.
    # Heads of kv group i_g start at i_g * G; within the group this program
    # owns heads [i_bh * BH, i_bh * BH + BH). (i_gbh * BH is only equal to
    # this when G is a multiple of BH.)
    head_start = i_g * G + i_bh * BH
    q_base = q + i_b * stride_qb + i_sq * stride_qm + head_start * stride_qh
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
    o_base = output + i_b * stride_ob + i_sq * stride_om + head_start * stride_oh
    l_base = lse + i_b * stride_lb + i_sq * stride_lm + head_start * stride_lh

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD

    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0).to(tl.float16)

    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, tq_msk, other=0.0).to(tl.float16)

    # Online-softmax state.
    # sum_exp starts at 1.0 (not 0): if no key is ever valid, the division
    # below yields 0 instead of NaN. The first valid chunk multiplies it by
    # alpha == exp2(-inf) == 0.
    # NOTE(review): the running max/sum and the accumulator are float16;
    # long top-k lists may lose precision — consider float32.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float16)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float16)
    acc = tl.zeros([BH, DP], dtype=tl.float16)

    log_scale: tl.constexpr = sm_scale * 1.44269504  # sm_scale * log2(e)

    # Causal rule: a query at position i_sq may only see kv positions <= i_sq.
    # NOTE(review): this assumes query and kv positions share the same origin
    # (SKV is not passed in); the general form would be
    # max(0, i_sq + SKV - SQ) if is_causal else SKV - 1.
    max_col = i_sq if is_causal else SQ - 1

    NK = tl.cdiv(K, BK)
    for ck in range(NK):
        # Bound-check the top-k slot index itself, not the strided offset.
        offs_k = BK * ck + offs_t
        t_msk = offs_k < K
        kv_ids = tl.load(t_base + offs_k * stride_tt, t_msk, other=-1)
        # Negative ids (padding / out-of-range slots) and future positions
        # are dropped.
        mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

        # Skip chunks with no usable key at all.
        if tl.max(mask_ids, axis=0) > 0:
            kv_ptr = (
                kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tl.load(kv_ptr, kv_msk, other=0.0).to(tl.float16)  # [DP, BK]

            tkv_ptr = (
                tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0).to(tl.float16)  # [TDP, BK]

            # Scores = (q[:D] . kv[:D] + q[D:] . kv[D:]) scaled into log2 space.
            qk = tl.dot(q_blk, kv_blk, out_dtype=tl.float16)
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float16) * log_scale
            qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]

            # Streaming softmax update: rescale the previous state by
            # exp2(old_max - new_max), then add this chunk.
            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            exp_qk = tl.math.exp2(qk - new_max[:, None]).to(tl.float16)
            sum_qk = tl.sum(exp_qk, axis=1)
            alpha = tl.math.exp2(max_log - new_max).to(tl.float16)
            sum_exp = sum_exp * alpha + sum_qk
            acc = acc * alpha[:, None]
            acc = tl.dot(
                exp_qk, kv_blk.trans(), acc, out_dtype=tl.float16
            )  # [BH, BK] @ [BK, DP] = [BH, DP]

            max_log = new_max.to(tl.float16)

    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_d[None, :]
    tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)

    # max_log is already in log2 units, so this is the LSE divided by ln 2.
    fin_log = max_log + tl.math.log2(sum_exp.to(tl.float32))
    l_ptr = l_base + offs_h * stride_lh
    tl.store(l_ptr, fin_log.to(q_blk.dtype), mask_h)
def triton_sparse_mla_fwd_interface(
    q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512
):
    """Run the sparse MLA forward Triton kernel.

    Args:
        q: Query tensor of shape (B, SQ, H, DT). Must be contiguous.
        kv: Key/value tensor of shape (B, SKV, VG, DT). Must be contiguous.
            The first ``d_v`` channels are used both as keys and as values;
            the remaining ``DT - d_v`` channels only contribute to the scores.
        indices: Tensor of shape (B, SQ, VG, K) giving, for each query, the kv
            positions it attends to. Must be contiguous. Negative entries and
            entries greater than the query position (the causal rule is
            always on) are ignored.
        sm_scale: Softmax scale. Defaults to ``DT ** -0.5``.
        return_p_sum: Must be False; only the forward pass is implemented.
        d_v: Value head dimension D.

    Returns:
        ``(output, lse)``:
        - ``output`` has shape (B, SQ, H, d_v) and q's dtype.
        - ``lse`` has shape (B, SQ, H) and q's dtype. It holds the
          log-sum-exp of the scaled scores in base-2 units.

    Raises:
        AssertionError: On unsupported arguments or mismatched shapes,
            including H not divisible by VG.
    """
    logger.debug("GEMS SPARSE_MLA_FWD_INTERFACE")
    is_causal = True
    assert return_p_sum is False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    B, SQ, H, DT = q.shape
    _, _, VG, _ = kv.shape

    # assert DT == 576, "you should assign dim otherwise"
    D = d_v

    assert kv.shape[-1] == DT
    TD = DT - D
    DP = triton.next_power_of_2(D)
    TDP = triton.next_power_of_2(TD)
    assert kv.shape[0] == B
    _, _, _, K = indices.shape
    assert indices.shape == (B, SQ, VG, K)
    # Every kv group must own the same number of query heads. Otherwise
    # G = H // VG truncates and the trailing heads are never written.
    assert H % VG == 0, "number of query heads must be divisible by kv groups"
    G = H // VG
    if sm_scale is None:
        sm_scale = DT**-0.5
    BH = max(16, min(64, triton.next_power_of_2(G)))
    NH = triton.cdiv(G, BH)
    BK = 32
    output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype)
    lse = torch.full((B, SQ, H), float("-inf"), device=q.device, dtype=q.dtype)
    # One program per (batch, query position, block of BH heads within a group).
    grid = (B, SQ, VG * NH)
    triton_sparse_mla_fwd[grid](
        q,
        kv,
        indices,
        sm_scale,
        output,
        lse,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),  # strides passed in [B, H, SQ, DT] order
        kv.stride(0),
        kv.stride(2),
        kv.stride(1),
        kv.stride(3),  # strides passed in [B, VG, SKV, DT] order
        indices.stride(0),
        indices.stride(2),
        indices.stride(1),
        indices.stride(3),  # strides passed in [B, VG, SQ, K] order
        output.stride(0),
        output.stride(2),
        output.stride(1),
        output.stride(3),  # strides passed in [B, H, SQ, D] order
        lse.stride(0),
        lse.stride(2),
        lse.stride(1),  # strides passed in [B, H, SQ] order
        SQ,
        K,
        D,
        TD,
        DP,
        TDP,
        G,
        BK,
        BH,
        is_causal,
    )
    return output, lse