Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/flash_mla.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, error, torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Rebind ``device`` from the imported runtime device object to its backend
# name string -- the rest of this module only needs the name (tensor
# allocation, capability queries). Note this intentionally shadows the
# ``device`` module object imported above.
device = device.name

# Child logger under the shared "flag_gems" root logger, keyed by module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

15# @triton.autotune( 

16# configs=[ 

17# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s) 

18# for h in [32, 64, 128] 

19# for n in [32, 64, 128] 

20# for w in [4, 8] 

21# for s in [1, 2] 

22# ], 

23# key=["head_num"] 

24# ) 

# One program computes the attention output for a tile of BLOCK_H query heads
# of a single (flattened) batch element, streaming over that element's paged
# KV cache with an online-softmax loop.
#
# MLA cache layout (as established by the loads below): each cache row of
# width HEAD_DIM holds the compressed "nope" part in its first HEAD_DIM_V
# columns -- this slice doubles as both K_nope and V -- and the rotary ("pe")
# part in the remaining HEAD_DIM - HEAD_DIM_V columns. Q rows use the same
# column split.
#
# EVEN_H is injected by the heuristic: True when head_num tiles evenly into
# BLOCK_H, letting the head-dimension loads/stores skip masking.
@triton.heuristics(
    values={
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # queries, rows of width HEAD_DIM (nope part then rope part)
    Kv_cache,  # paged latent KV cache, one row of width HEAD_DIM per token
    Req_to_tokens,  # page table: per-request page numbers
    B_seq_len,  # per-request cached sequence length
    O,  # output, rows of width HEAD_DIM_V
    sm_scale,  # softmax scale (1/sqrt(head_dim) at the call site)
    head_num,  # total number of query heads
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,  # query heads processed per program
    BLOCK_N: tl.constexpr,  # cached tokens consumed per inner iteration
    EVEN_H: tl.constexpr,  # set by the heuristic above
    PAGE_SIZE: tl.constexpr,  # tokens per cache page
    HEAD_DIM_V: tl.constexpr,  # width of the nope/value slice
    HEAD_DIM: tl.constexpr,  # full head width (nope + rope)
):
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point the page table at this request's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Q offsets for the compressed ("nope") slice: columns [0, HEAD_DIM_V).
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Q offsets for the rotary ("pe") slice: columns [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Last head tile may run past head_num; mask the tail rows.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax state: running row max, running denominator, and the
    # running (unnormalized) weighted sum of V.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    # Full BLOCK_N iterations first; a masked tail pass handles the remainder.
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate token positions to cache rows through the page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # The nope slice of the cache serves as both K_nope (transposed) and V.
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Online-softmax update: rescale previous accumulator by the max
        # shift, then fold in this block's probabilities.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail pass: same computation, masked to the last partial block.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Out-of-range positions must not contribute to the softmax.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # Normalize by the softmax denominator and write the output tile.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])

154 

155 

def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Launch the FlashMLA attention kernel over a paged latent KV cache.

    Args:
        q: query tensor of shape [b, s_q, h_q, d].
        block_table: per-request page table mapping requests to cache pages.
        blocked_k: paged latent KV cache; viewed here as rows of width ``d``.
        max_seqlen_pad: unused; kept for interface compatibility.
        block_size: number of tokens per cache page.
        b: batch size.
        s_q: query sequence length. The kernel applies no per-query causal
            masking, so this path effectively targets decode (s_q == 1).
        cache_seqlens: int tensor of cached sequence lengths, one per request.
        h_q: number of query heads (used for the output allocation).
        h_kv: number of KV heads; unused (MLA shares one latent cache).
        d: total head dim (nope + rope parts); must exceed ``dv``.
        dv: value / "nope" head dim.
        causal: must be True; non-causal attention is not supported.

    Returns:
        Output tensor of shape [b, s_q, h_q, dv], same dtype as ``q``.
    """
    logger.debug("GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # Derive shapes from q itself into fresh names instead of rebinding the
    # s_q / d parameters (the original shadowing made the epilogue's use of
    # s_q ambiguous), and flatten (batch, seq) into one leading dim.
    batch_size, seq_len_q, head_num, head_dim = q.shape
    q = q.view([-1, head_num, head_dim]).contiguous()
    blocked_k = blocked_k.view([-1, head_dim]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(head_dim)

    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

    # Pick tile sizes by compute capability: sm9x tolerates larger head tiles
    # and deeper software pipelining than sm8x; anything else is unsupported.
    major, _ = torch.cuda.get_device_capability(device)
    if major == 9:
        BLOCK_H = 64
        num_stages = 3
    elif major == 8:
        BLOCK_H = 32
        num_stages = 2
    else:
        error.backend_not_support(device)
    BLOCK_N = 64
    # NOTE(review): the grid's second axis is the original batch size, while
    # q was flattened to b * s_q rows -- these agree only when s_q == 1;
    # confirm against callers before relying on s_q > 1.
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        batch_size,
    )
    with torch_device_fn.device(device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # strides
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=head_dim,
            num_warps=8,
            num_stages=num_stages,
        )

    return o.view([b, s_q, h_q, dv])