Coverage for src/flag_gems/fused/flash_mla.py: 33%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, error, torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.device_info import get_device_capability
# Snapshot runtime identifiers once at import time.
vendor_name = device.vendor_name
# NOTE: rebinds `device` from the imported runtime device object to its
# `.name` attribute; all later uses in this module refer to that value.
device = device.name
# Module-level logger, per the standard logging convention.
logger = logging.getLogger(__name__)
# Autotune search space kept for reference; the host wrapper currently picks
# BLOCK_H / num_stages per device capability instead of autotuning.
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s)
#         for h in [32, 64, 128]
#         for n in [32, 64, 128]
#         for w in [4, 8]
#         for s in [1, 2]
#     ],
#     key=["head_num"]
# )
@triton.heuristics(
    values={
        # No per-head bounds mask is needed when head_num divides into blocks.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # queries, addressed as batch*stride_q_bs + head*stride_q_h + dim
    Kv_cache,  # paged cache row per token: [0,HEAD_DIM_V)=compressed kv, [HEAD_DIM_V,HEAD_DIM)=rope K
    Req_to_tokens,  # page table: one row per request, maps token block -> page number
    B_seq_len,  # per-batch KV sequence lengths
    O,  # output, addressed via stride_o_b / stride_o_h (last dim contiguous)
    sm_scale,  # softmax scale applied to q @ k scores
    head_num,  # number of query heads
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,  # NOTE(review): accepted but unused; last output dim is indexed with stride 1
    BLOCK_H: tl.constexpr,  # query heads handled per program
    BLOCK_N: tl.constexpr,  # KV tokens handled per inner-loop tile
    EVEN_H: tl.constexpr,  # True iff head_num % BLOCK_H == 0 (skip head masking)
    PAGE_SIZE: tl.constexpr,  # tokens per KV-cache page
    HEAD_DIM_V: tl.constexpr,  # no-rope (compressed) dim; also the V dim
    HEAD_DIM: tl.constexpr,  # full q/k dim = HEAD_DIM_V + rope dim
):
    """MLA decode attention over a paged KV cache; one (head-block, batch) per program.

    The query is split into a no-rope slice (first HEAD_DIM_V dims) that is
    matched against the compressed kv rows, and a rope slice (remaining dims)
    matched against the k_pe tail of the same cache rows; the compressed kv
    row also serves as V.  Scores are reduced with a streaming
    (flash-attention-style) softmax over BLOCK_N-sized KV tiles.
    """
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point the page table at this request's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # No-rope query slice: dims [0, HEAD_DIM_V).
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Rope query slice: dims [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Tail head-block: mask heads beyond head_num.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax running state: per-head max, per-head sum, unnormalized output.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    loop_time = cur_batch_seq_len // BLOCK_N  # number of full tiles
    remainder = cur_batch_seq_len % BLOCK_N  # size of the partial tail tile
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical token positions to physical cache rows via the page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        # The compressed kv row doubles as the (transposed) no-rope key and as V.
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        # Rope key slice, loaded directly in transposed [dim, token] layout.
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Streaming softmax update: rescale previous state to the new running max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail tile: same as the main loop but with out-of-range tokens masked.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Masked positions score -inf so they contribute exp(-inf) = 0 to the softmax.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        # e_max is deliberately not refreshed here; it is not read again.

    # Normalize the accumulated output and store it.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])
def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Host-side launcher for the MLA decode attention kernel.

    Flattens the query and paged KV-cache tensors, selects launch parameters
    for the current device generation, runs ``flash_mla_attn_kernel``, and
    returns the output reshaped to ``[b, s_q, h_q, dv]``.
    """
    logger.debug("GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # Re-derive geometry from the tensor itself (note: s_q and d are rebound).
    num_seqs, s_q, num_heads, d = q.shape
    q = q.view([-1, num_heads, d]).contiguous()
    blocked_k = blocked_k.view([-1, d]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    # Standard 1/sqrt(dim) attention scale over the full q/k dimension.
    sm_scale = 1 / math.sqrt(d)

    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

    # Launch configuration by device generation / vendor.
    major, _ = get_device_capability()
    if major == 9:
        BLOCK_H, num_stages = 64, 3
    elif major == 8:
        BLOCK_H, num_stages = 32, 2
    elif (major, vendor_name) in ((7, "iluvatar"), (3, "mthreads")):
        BLOCK_H, num_stages = 32, 1
    else:
        error.backend_not_support(device)
    BLOCK_N = 64

    # One program per (head block, sequence).
    grid = (
        triton.cdiv(num_heads, BLOCK_H),
        num_seqs,
    )
    with torch_device_fn.device(device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            num_heads,
            # stride
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=num_stages,
        )
    return o.view([b, s_q, h_q, dv])