Coverage for src/flag_gems/fused/flash_mla.py: 14%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, error, torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.device_info import get_device_capability 

11 

# Capture the vendor string before `device` is rebound on the next line.
vendor_name = device.vendor_name
# Rebind `device` from the imported runtime object to its name string; all
# later uses in this module (torch.empty(device=...), torch_device_fn.device(...))
# expect the string form.
device = device.name
logger = logging.getLogger(__name__)

15 

16 

17# @triton.autotune( 

18# configs=[ 

19# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s) 

20# for h in [32, 64, 128] 

21# for n in [32, 64, 128] 

22# for w in [4, 8] 

23# for s in [1, 2] 

24# ], 

25# key=["head_num"] 

26# ) 

# Decode-style MLA attention over a paged KV cache using online (flash) softmax.
# One program computes the full output (HEAD_DIM_V values per head) for a tile of
# BLOCK_H query heads of one batch element.
#
# Layout assumptions, grounded in the index math below:
#   * A Q row is [compressed "nope" part (HEAD_DIM_V) | rope part (HEAD_DIM - HEAD_DIM_V)].
#   * A Kv_cache row shares that layout; its first HEAD_DIM_V entries double as V
#     (MLA stores the compressed KV once and reuses it for both K-nope and V).
#   * Req_to_tokens maps a logical KV position (pos // PAGE_SIZE) to a page id.
@triton.heuristics(
    values={
        # When head_num is a multiple of BLOCK_H, per-head masking can be skipped.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # queries, addressed via stride_q_bs / stride_q_h
    Kv_cache,  # paged KV cache, one HEAD_DIM-wide row per token slot
    Req_to_tokens,  # page table, one row per request
    B_seq_len,  # per-batch KV sequence lengths
    O,  # output, addressed via stride_o_b / stride_o_h
    sm_scale,  # softmax scale (1/sqrt(d) on the host side)
    head_num,  # number of query heads
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,  # NOTE(review): unused below — the last output dim is assumed contiguous
    BLOCK_H: tl.constexpr,  # head-tile size
    BLOCK_N: tl.constexpr,  # KV-tile size
    EVEN_H: tl.constexpr,  # supplied by the heuristic above
    PAGE_SIZE: tl.constexpr,  # tokens per KV-cache page
    HEAD_DIM_V: tl.constexpr,  # width of the value ("nope") part
    HEAD_DIM: tl.constexpr,  # full head width (nope + rope)
):
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point the page table at this request's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    # The BLOCK_H heads this program is responsible for.
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Offsets of the "nope" (value-sized) slice of Q: columns [0, HEAD_DIM_V).
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Offsets of the rope slice of Q: columns [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Partial head tile: mask rows past head_num.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax state: running max, running denominator, weighted-V accumulator.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    # Process full BLOCK_N tiles first; the ragged tail is handled after the loop.
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical KV positions to physical cache rows via the page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # The first HEAD_DIM_V columns of a cache row are V, and (transposed) the
        # nope part of K — MLA shares this storage.
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        # The rope part of K lives in the tail columns of the same cache row.
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Standard flash-attention rescale against the new running max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Ragged tail: same math as the loop body, with positions past the
        # sequence length masked out of loads and forced to -inf in the scores.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Out-of-range positions contribute exp(-inf) = 0 to p and e_sum.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        # e_max is not updated here: this is the last tile, so only e_sum/acc matter.
        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # Final normalization: divide the accumulated weighted V by the softmax denominator.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])

156 

157 

def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Run multi-head latent attention (MLA) decode over a paged KV cache.

    Args:
        q: query tensor of shape [b, s_q, h_q, d].
        block_table: page table mapping logical KV blocks to cache pages.
        blocked_k: paged KV cache; each row holds d values (dv "nope" + d - dv rope).
        max_seqlen_pad: padded max sequence length (unused here; kept for API parity).
        block_size: number of tokens per KV-cache page.
        b, s_q: batch size and query length (both re-read from q.shape below).
        cache_seqlens: per-batch KV sequence lengths.
        h_q, h_kv: query / KV head counts (h_kv is not used by this kernel).
        d, dv: full head dim and value ("nope") dim; requires d > dv.
        causal: must be True; non-causal attention is not supported.

    Returns:
        Attention output of shape [b, s_q, h_q, dv], same dtype as q.
    """
    logger.debug("GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # Re-derive sizes from q itself, then flatten (batch, s_q) into one token axis.
    bsz, s_q, head_num, d = q.shape
    q = q.view([-1, head_num, d]).contiguous()
    blocked_k = blocked_k.view([-1, d]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(d)

    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

    # Per-architecture tile/pipeline choices; unsupported backends abort here.
    major, _ = get_device_capability()
    if major == 9:
        block_h, stages = 64, 3
    elif major == 8:
        block_h, stages = 32, 2
    elif (major, vendor_name) in ((7, "iluvatar"), (3, "mthreads")):
        block_h, stages = 32, 1
    else:
        # NOTE(review): assumed to raise; otherwise block_h/stages are unbound.
        error.backend_not_support(device)
    block_n = 64

    # One program per (head tile, batch row).
    # NOTE(review): the grid's second axis is bsz (= b) while q/o were flattened to
    # b * s_q rows — this only covers every row when s_q == 1 (decode); confirm
    # with callers before relying on s_q > 1.
    grid = (triton.cdiv(head_num, block_h), bsz)

    with torch_device_fn.device(device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # strides
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=block_h,
            BLOCK_N=block_n,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=stages,
        )

    return o.view([b, s_q, h_q, dv])