Coverage for src/flag_gems/runtime/backend/_metax/fused/flash_mla.py: 0%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, error, torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.device_info import get_device_capability 

11 

# Capture the vendor string before ``device`` is rebound below.
vendor_name = device.vendor_name
# NOTE: rebinds ``device`` from the runtime device object to its *name*
# string; everything below (torch.empty(device=...), torch_device_fn.device)
# uses the string form.
device = device.name
logger = logging.getLogger(__name__)

15 

16 

@triton.heuristics(
    values={
        # When head_num divides evenly into BLOCK_H every head tile is full,
        # so the per-head bounds mask can be compiled out.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # queries; indexed as batch * stride_q_bs + head * stride_q_h + dim
    Kv_cache,  # paged KV cache rows of width HEAD_DIM (nope part + rope part)
    Req_to_tokens,  # per-request page table: logical chunk -> physical page id
    B_seq_len,  # per-request KV sequence length
    O,  # output, indexed via (stride_o_b, stride_o_h, stride_o_s)
    sm_scale,  # softmax scale applied to the combined logits
    head_num,  # number of query heads (for masking the last head tile)
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,  # query heads processed per program
    BLOCK_N: tl.constexpr,  # KV positions processed per inner-loop step
    EVEN_H: tl.constexpr,  # head_num % BLOCK_H == 0 (set by the heuristic)
    PAGE_SIZE: tl.constexpr,  # KV cache rows per page
    HEAD_DIM_V: tl.constexpr,  # value ("nope") head dim
    HEAD_DIM: tl.constexpr,  # total head dim = nope + rope
):
    """MLA decode attention over a paged KV cache, one (head-tile, batch) per program.

    Grid axis 0 tiles the query heads in BLOCK_H chunks; axis 1 is the batch
    element. Each query row is split into a "nope" part (dims [0, HEAD_DIM_V))
    and a rope part (dims [HEAD_DIM_V, HEAD_DIM)); cache rows hold both parts
    concatenated, and the first HEAD_DIM_V dims of a cached row double as the
    value vector. Softmax is evaluated online (flash-attention style) while
    looping over the KV sequence in BLOCK_N steps, with Req_to_tokens mapping
    logical positions to physical pages of PAGE_SIZE rows.
    """
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point the page table at this request's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # "nope" (compressed/value) dims of the query.
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # rope dims of the query: the tail [HEAD_DIM_V, HEAD_DIM) of each row.
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Last head tile may overrun head_num; mask those rows.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax state (fp32): running max, running denominator, and the
    # running weighted-value accumulator.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    loop_time = cur_batch_seq_len // BLOCK_N  # number of full BLOCK_N chunks
    remainder = cur_batch_seq_len % BLOCK_N  # leftover tail positions
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical positions offs_n to physical cache rows via the
        # page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # The first HEAD_DIM_V dims of a cached row serve both as value and
        # (transposed) as the "nope" key.
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        # rope part of the cached rows, loaded directly in [dim, pos] layout.
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Online-softmax update: rescale existing state to the new max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail chunk: mask positions beyond the true sequence length.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,  # safe placeholder page; the loads below are masked too
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Padded positions must not contribute to the softmax.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        # e_max is not refreshed here: this is the final chunk.
        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # Normalize the accumulator and write the output tile.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])

146 

147 

def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Launch the paged MLA decode-attention kernel.

    Args:
        q: query tensor; its actual shape provides batch, heads and head dim.
        block_table: per-request page table (logical chunk -> physical page).
        blocked_k: paged KV cache; flattened to rows of width ``d``.
        max_seqlen_pad: padded max sequence length (not used by this launcher).
        block_size: cache rows per page (kernel PAGE_SIZE).
        b, s_q: batch size and query length used to shape the output.
        cache_seqlens: per-request KV sequence lengths.
        h_q, h_kv: query / KV head counts (h_kv unused here).
        d: total head dim (nope + rope); dv: value ("nope") head dim.
        causal: must be truthy — only causal attention is supported.

    Returns:
        Attention output of shape [b, s_q, h_q, dv].
    """
    logger.debug("METAX GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # s_q and d are deliberately re-derived from q's concrete shape.
    batch_size, s_q, head_num, d = list(q.shape)
    flat_q = q.view([-1, head_num, d]).contiguous()
    kv_cache = blocked_k.view([-1, d]).contiguous()
    page_table = block_table.contiguous()
    seq_lens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(d)

    out = torch.empty([b * s_q, h_q, dv], dtype=flat_q.dtype, device=device)

    # Pick the head-tile width and pipeline depth for the detected hardware.
    major, _ = get_device_capability()
    if major == 9:
        tile_heads, pipeline_stages = 64, 3
    elif major == 8:
        tile_heads, pipeline_stages = 16, 2
    elif major == 7 and vendor_name == "iluvatar":
        tile_heads, pipeline_stages = 32, 1
    else:
        error.backend_not_support(device)
    seq_tile = 32

    # One program per (head tile, batch element).
    launch_grid = (
        triton.cdiv(head_num, tile_heads),
        batch_size,
    )
    with torch_device_fn.device(device):
        flash_mla_attn_kernel[launch_grid](
            flat_q,
            kv_cache,
            page_table,
            seq_lens,
            out,
            sm_scale,
            head_num,
            # stride
            flat_q.stride(0),
            flat_q.stride(1),
            kv_cache.stride(-2),
            page_table.stride(0),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            BLOCK_H=tile_heads,
            BLOCK_N=seq_tile,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=pipeline_stages,
        )

    return out.view([b, s_q, h_q, dv])