# Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/flash_mla.py: 0%
# 99 statements
# coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, error, torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
# Rebind `device` from the runtime's device descriptor to its plain name string;
# the string is what the allocation and launch code below expects.
device = device.name
# Per-module child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
15# @triton.autotune(
16# configs=[
17# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s)
18# for h in [32, 64, 128]
19# for n in [32, 64, 128]
20# for w in [4, 8]
21# for s in [1, 2]
22# ],
23# key=["head_num"]
24# )
@triton.heuristics(
    values={
        # When head_num divides BLOCK_H evenly, every head lane is valid and
        # the kernel can skip head masking on loads/stores.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # query tensor; rows indexed by (batch, head), depth HEAD_DIM
    Kv_cache,  # paged latent KV cache; one row per cached token, depth >= HEAD_DIM
    Req_to_tokens,  # per-request page table: token block index -> page number
    B_seq_len,  # per-batch cached sequence lengths
    O,  # output tensor, depth HEAD_DIM_V
    sm_scale,  # softmax scaling factor applied to the logits
    head_num,  # number of query heads
    stride_q_bs,  # Q stride between batch rows
    stride_q_h,  # Q stride between heads
    stride_kv_bs,  # KV-cache stride between token rows
    stride_req_to_tokens_bs,  # page-table stride between requests
    stride_o_b,  # O stride between batch rows
    stride_o_h,  # O stride between heads
    stride_o_s,  # NOTE(review): accepted but never used below
    BLOCK_H: tl.constexpr,  # heads handled per program
    BLOCK_N: tl.constexpr,  # KV tokens handled per inner iteration
    EVEN_H: tl.constexpr,  # set by the heuristic above
    PAGE_SIZE: tl.constexpr,  # tokens per KV-cache page
    HEAD_DIM_V: tl.constexpr,  # no-rope (compressed) head dim; also the value dim
    HEAD_DIM: tl.constexpr,  # full head dim = no-rope part + rope part
):
    """MLA attention over a paged KV cache using an online softmax.

    Grid: axis 0 tiles the query heads in chunks of BLOCK_H, axis 1 is the
    batch. Each program attends one BLOCK_H slice of heads for one batch
    entry across that entry's whole cached sequence.

    The cache stores, per token, the compressed KV vector in columns
    [0, HEAD_DIM_V) and the rope part in columns [HEAD_DIM_V, HEAD_DIM);
    the compressed part doubles as V (classic MLA absorption layout).
    """
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point the page table at this request's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    # Head indices covered by this program.
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Offsets of the no-rope (compressed) part of q: columns [0, HEAD_DIM_V).
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Offsets of the rope part of q: columns [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        # All BLOCK_H head lanes are in range; load unmasked.
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Tail head tile: mask lanes past head_num.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax state: running max, running denominator, output accumulator.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    # Full BLOCK_N tiles, then one masked tail tile for the remainder.
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical token positions to physical cache rows via the
        # page table: page number * PAGE_SIZE + in-page offset.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # Compressed KV rows: used both as K (no-rope) and as V.
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        # Rope part of K, loaded directly in transposed [dim, token] shape.
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Online softmax update: rescale previous state to the new max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail tile: same as above, but mask positions past the sequence end.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Out-of-range columns contribute exp(-inf) = 0 to the softmax.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        # e_max is not updated here: nothing follows the tail tile.

    # Final normalization and store of the HEAD_DIM_V-wide output.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])
def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Launch the paged MLA flash-attention kernel and return the output.

    Args:
        q: query tensor of shape (b, s_q, h_q, d).
        block_table: per-request page table into the paged KV cache.
        blocked_k: paged latent KV cache, flattened to rows of width d.
        max_seqlen_pad: unused here (kept for interface compatibility).
        block_size: tokens per KV-cache page (kernel PAGE_SIZE).
        b, s_q: batch size and query length; s_q/d are re-derived from
            q.shape below. NOTE(review): the launch grid covers q.shape[0]
            batch entries while the output holds b * s_q rows — this looks
            like it assumes decode (s_q == 1); confirm with callers.
        cache_seqlens: per-batch cached sequence lengths.
        h_q, h_kv: query / kv head counts (h_kv unused here).
        d, dv: full head dim (rope + no-rope) and value (no-rope) head dim.
        causal: must be truthy; non-causal is not supported.

    Returns:
        Attention output of shape (b, s_q, h_q, dv).
    """
    logger.debug("GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # Re-derive the geometry from q itself; note s_q and d are rebound here.
    num_reqs, s_q, head_num, d = list(q.shape)
    # Flatten queries to [rows, heads, dim] and make every operand contiguous
    # so the simple stride arithmetic inside the kernel holds.
    q = q.view([-1, head_num, d]).contiguous()
    blocked_k = blocked_k.view([-1, d]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    # Standard 1/sqrt(d) attention scaling over the full (rope + no-rope) dim.
    sm_scale = 1 / math.sqrt(d)

    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

    # Pick the tile height and pipeline depth per compute capability.
    compute_major, _ = torch.cuda.get_device_capability(device)
    if compute_major == 9:
        block_h, pipeline_stages = 64, 3
    elif compute_major == 8:
        block_h, pipeline_stages = 32, 2
    else:
        error.backend_not_support(device)

    block_n = 64
    # Axis 0 tiles the heads; axis 1 walks the batch entries.
    grid = (triton.cdiv(head_num, block_h), num_reqs)
    with torch_device_fn.device(device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # strides
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=block_h,
            BLOCK_N=block_n,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=pipeline_stages,
        )

    return o.view([b, s_q, h_q, dv])