Coverage for src/flag_gems/fused/DSA/sparse_mla.py: 10%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

# Launch configurations tried by the autotuner for the sparse MLA forward
# kernel. `num_stages` / `num_warps` are compiler launch options and therefore
# must be passed as keyword arguments of `triton.Config`; the first positional
# dict is reserved for kernel meta-parameters (none are tuned here).
spar_mla_fwd_configs = [
    triton.Config({}, num_stages=4, num_warps=8),
    triton.Config({}, num_stages=2, num_warps=4),
]

13 

14 

@triton.autotune(  # Decorate the kernel
    configs=spar_mla_fwd_configs,
    key=["K", "is_causal"],
)
@triton.jit
def triton_sparse_mla_fwd(
    q,
    kv,
    indices,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kvb,
    stride_kvg,
    stride_kvn,
    stride_kvd,
    stride_tb,
    stride_tg,
    stride_tm,
    stride_tt,  # indices dim
    stride_ob,
    stride_oh,
    stride_om,
    stride_od,
    stride_lb,
    stride_lh,
    stride_lm,
    SQ: tl.constexpr,  # seqlen
    K: tl.constexpr,  # topk
    D: tl.constexpr,  # QKV dim
    TD: tl.constexpr,  # tail dim
    DP: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,  # group_size
    BK: tl.constexpr,
    BH: tl.constexpr,
    is_causal: tl.constexpr,
):
    """Sparse attention forward pass over a per-query list of KV indices.

    One program handles one (batch, query position, KV group, head block)
    combination: program ids 0 and 1 select the batch and the query position,
    program id 2 encodes ``group * cdiv(G, BH) + head_block``.

    For its block of up to ``BH`` heads the program
      * loads the query split into a main part (first ``D`` features) and a
        tail part (the following ``TD`` features),
      * walks the ``K`` entries of ``indices`` in chunks of ``BK``, gathering
        the referenced KV rows,
      * scores them with ``q . kv`` over all ``D + TD`` features and keeps a
        running (online) softmax in the base-2 domain,
      * accumulates the softmax-weighted sum of the first ``D`` KV features.

    Results written: the normalised accumulator to ``output`` and
    ``max + log2(sum_exp)`` (i.e. the log-sum-exp divided by ln 2) to ``lse``.

    ``DP`` / ``TDP`` are the ``tl.arange`` extents used for the main and tail
    feature axes; lanes beyond ``D`` / ``TD`` are masked out. Index entries
    that are negative or greater than the allowed column bound are ignored.

    NOTE(review): all intermediate math (dot products, running max/sum,
    accumulator) is carried out in float16 as in the original implementation.
    """
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NH = tl.cdiv(G, BH)
    i_g, i_bh = i_gbh // NH, i_gbh % NH

    # First head processed by this program. Using i_g * G (rather than
    # i_gbh * BH) keeps the addressing correct when G is not a multiple of BH;
    # both expressions coincide when it is.
    head_start = i_g * G + i_bh * BH

    q_base = q + i_b * stride_qb + i_sq * stride_qm + head_start * stride_qh
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
    o_base = output + i_b * stride_ob + i_sq * stride_om + head_start * stride_oh
    l_base = lse + i_b * stride_lb + i_sq * stride_lm + head_start * stride_lh

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_od = tl.arange(0, DP)
    offs_t = tl.arange(0, BK)
    # Heads are masked relative to the group so a partial last head block
    # never touches heads of the next group.
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD
    mask_od = mask_d

    # Main query part: [BH, DP]
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0).to(tl.float16)

    # Tail query part: [BH, TDP]
    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, tq_msk, other=0.0).to(tl.float16)

    # Online-softmax state. sum_exp starts at 1 so that a query with no valid
    # index yields output 0 (0 / 1) and lse -inf instead of a division by zero.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float16)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float16)
    acc = tl.zeros([BH, DP], dtype=tl.float16)
    qk = tl.zeros([BH, BK], dtype=tl.float16)  # overwritten before every use

    # Fold log2(e) into the scale so exp2/log2 can be used below.
    log_scale: tl.constexpr = sm_scale * 1.44269504

    # Largest KV index that may be attended to.
    # NOTE(review): the non-causal bound uses SQ because the KV length is not
    # passed to the kernel — confirm if non-causal use is ever enabled.
    max_col = i_sq if is_causal else SQ - 1

    NK = tl.cdiv(K, BK)
    for ck in range(NK):
        # Bounds-check the logical position in the top-k list, not the
        # stride-scaled offset, so the mask is right for any stride_tt.
        t_idx = BK * ck + offs_t
        t_msk = t_idx < K
        t_ptr = t_base + t_idx * stride_tt
        kv_ids = tl.load(t_ptr, t_msk, other=-1)
        mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

        # Skip the whole chunk when none of its indices is usable.
        if tl.max(mask_ids, axis=0) > 0:
            kv_ptr = (
                kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tl.load(kv_ptr, kv_msk, other=0.0).to(tl.float16)  # [DP, BK]

            tkv_ptr = (
                tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0).to(tl.float16)  # [TDP, BK]

            # Score = (main . main + tail . tail) * scale * log2(e)
            qk = tl.dot(q_blk, kv_blk, out_dtype=tl.float16)
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float16) * log_scale

            qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]

            # Online softmax update: rescale previous sum/accumulator by alpha.
            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            exp_qk = tl.math.exp2(qk - new_max[:, None]).to(tl.float16)
            sum_qk = tl.sum(exp_qk, axis=1)
            alpha = tl.math.exp2(max_log - new_max).to(tl.float16)
            sum_exp = sum_exp * alpha + sum_qk
            acc = acc * alpha[:, None]
            acc = tl.dot(
                exp_qk, kv_blk.trans(), acc, out_dtype=tl.float16
            )  # [BH, BK] @ [BK, DP] = [BH, DP]

            max_log = new_max.to(tl.float16)

    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_od[None, :]
    tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)

    fin_log = max_log + tl.math.log2(sum_exp.to(tl.float32))  # return lse / ln2
    l_ptr = l_base + offs_h * stride_lh
    l_msk = mask_h
    tl.store(l_ptr, fin_log.to(q_blk.dtype), l_msk)

150 

151 

def triton_sparse_mla_fwd_interface(
    q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512
):
    """Launch the sparse MLA forward kernel and return ``(output, lse)``.

    Args:
        q: contiguous tensor of shape ``(B, SQ, H, DT)``.
        kv: contiguous 4-D tensor whose first dim is ``B``, third dim is the
            number of KV groups ``VG`` and last dim is ``DT``.
        indices: contiguous tensor of shape ``(B, SQ, VG, K)`` holding, for
            every query position and KV group, the KV rows to attend to.
        sm_scale: softmax scale; defaults to ``DT ** -0.5`` when ``None``.
        return_p_sum: must be ``False`` (forward only).
        d_v: number of leading features used as value / main dimension ``D``;
            the remaining ``DT - d_v`` features form the tail dimension.

    Returns:
        ``output`` of shape ``(B, SQ, H, d_v)`` and ``lse`` of shape
        ``(B, SQ, H)``, both with ``q``'s dtype and device. ``lse`` holds the
        base-2 log-sum-exp written by the kernel; entries the kernel does not
        write keep their initial values (0 for ``output``, ``-inf`` for ``lse``).

    Raises:
        AssertionError: if any of the shape / layout preconditions is violated.
    """
    logger.debug("GEMS SPARSE_MLA_FWD_INTERFACE")
    is_causal = True
    assert return_p_sum is False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    B, SQ, H, DT = q.shape
    _, _, VG, _ = kv.shape

    D = d_v

    assert kv.shape[-1] == DT
    TD = DT - D
    DP = triton.next_power_of_2(D)
    TDP = triton.next_power_of_2(TD)
    assert kv.shape[0] == B
    _, _, _, K = indices.shape
    assert indices.shape == (B, SQ, VG, K)
    # The kernel needs a non-empty main and tail feature range; otherwise the
    # padded extents above are not valid block sizes.
    assert 0 < D < DT, "d_v must satisfy 0 < d_v < q.shape[-1]"
    # Heads are split evenly across KV groups; a remainder would leave the
    # trailing heads unprocessed.
    assert H % VG == 0, "number of query heads must be a multiple of kv groups"
    G = H // VG
    if sm_scale is None:
        sm_scale = DT**-0.5
    # Heads per program: power of two, clamped to [16, 64].
    BH = max(16, min(64, triton.next_power_of_2(G)))
    NH = triton.cdiv(G, BH)
    BK = 32
    output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype)
    lse = torch.full((B, SQ, H), float("-inf"), device=q.device, dtype=q.dtype)
    grid = (B, SQ, VG * NH)  # (batch, query position, group * head blocks)
    # Strides are passed in (batch, head/group, sequence, feature) order, which
    # is why dims 1 and 2 appear swapped relative to the tensor layouts.
    triton_sparse_mla_fwd[grid](
        q,
        kv,
        indices,
        sm_scale,
        output,
        lse,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),  # [B, H, SQ, DT]
        kv.stride(0),
        kv.stride(2),
        kv.stride(1),
        kv.stride(3),  # [B, VG, SKV, DT]
        indices.stride(0),
        indices.stride(2),
        indices.stride(1),
        indices.stride(3),  # [B, VG, SQ, K]
        output.stride(0),
        output.stride(2),
        output.stride(1),
        output.stride(3),  # [B, H, SQ, D]
        lse.stride(0),
        lse.stride(2),
        lse.stride(1),  # [B, H, SQ]
        SQ,
        K,
        D,
        TD,
        DP,
        TDP,
        G,
        BK,
        BH,
        is_causal,
    )
    return output, lse