Coverage for src/flag_gems/runtime/backend/_metax/fused/flash_mla.py: 0%
104 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, error, torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.device_info import get_device_capability
# Capture the vendor name first: the next line rebinds `device` from the
# imported runtime `device` module to its plain device-name string, shadowing
# the module for the rest of this file.
vendor_name = device.vendor_name
device = device.name
logger = logging.getLogger(__name__)
# Paged-KV attention kernel for Multi-head Latent Attention (MLA) decode.
# The compressed KV ("nope"/ckv, first HEAD_DIM_V dims of each cache row)
# doubles as both K and V; the rope part (dims HEAD_DIM_V..HEAD_DIM) only
# contributes to the QK logits. Softmax is computed online, flash-attention
# style, one BLOCK_N slice of KV at a time.
@triton.heuristics(
    values={
        # True when head_num tiles evenly into BLOCK_H, allowing unmasked
        # loads/stores along the head dimension.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,  # queries; indexed by (batch, head, dim) via the strides below
    Kv_cache,  # paged KV cache; one HEAD_DIM row per physical token slot
    Req_to_tokens,  # per-request page table (batch -> physical page numbers)
    B_seq_len,  # per-request KV sequence length, one int per batch entry
    O,  # output; indexed by (batch, head, dim) via the strides below
    sm_scale,  # softmax scaling applied to the full (nope + rope) logits
    head_num,  # number of query heads
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,  # heads handled per program
    BLOCK_N: tl.constexpr,  # KV positions handled per loop iteration
    EVEN_H: tl.constexpr,  # head_num % BLOCK_H == 0 (heuristic above)
    PAGE_SIZE: tl.constexpr,  # tokens per KV-cache page
    HEAD_DIM_V: tl.constexpr,  # value / "nope" head dim
    HEAD_DIM: tl.constexpr,  # full head dim (nope + rope)
):
    # 2D launch grid: axis 0 tiles the heads, axis 1 is the batch entry.
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Offsets for the "nope" (compressed-KV) part of q: dims [0, HEAD_DIM_V).
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Offsets for the rope part of q: dims [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Tail tile along heads: mask rows past head_num.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax running state: row max, row sum, and output accumulator.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    # Full BLOCK_N slices of the KV sequence: no length masking needed.
    for i in range(0, loop_time):
        # Translate logical token index -> physical cache row via page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        # The compressed KV serves as both V and (transposed) K.
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Online-softmax update: rescale the old accumulator by
        # exp(old_max - new_max) before folding in the new slice.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    # Tail slice (< BLOCK_N tokens): same computation with length masking.
    if remainder:
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Out-of-range positions get -inf so exp() makes them contribute 0.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # Normalize by the softmax denominator and store the BLOCK_H x HEAD_DIM_V
    # output tile.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])
def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Multi-head Latent Attention over a paged KV cache (Metax backend).

    Args:
        q: query tensor of shape [b, s_q, h_q, d].
        block_table: per-request page table mapping logical KV pages to
            physical pages of ``blocked_k``.
        blocked_k: paged, compressed KV cache; last dim is ``d`` (``dv``
            "nope" dims followed by ``d - dv`` rope dims).
        max_seqlen_pad: unused here; kept for interface compatibility.
        block_size: tokens per KV-cache page (kernel PAGE_SIZE).
        b: batch size.
        s_q: query sequence length.
        cache_seqlens: per-request KV lengths, shape [b].
        h_q: number of query heads; h_kv: number of KV heads (unused here).
        d: full head dim; dv: value ("nope") head dim; requires d > dv.
        causal: must be True; non-causal attention is not supported.

    Returns:
        Output tensor of shape [b, s_q, h_q, dv] with q's dtype.

    Raises:
        ValueError: if ``causal`` is False or ``d <= dv``.
    """
    logger.debug("METAX GEMS FLASH MLA")
    # Explicit raises instead of `assert` so validation survives `python -O`.
    if not causal:
        raise ValueError("causal False not supported")
    if d <= dv:
        raise ValueError("mla with rope dim should be larger than no rope dim")

    # Rebind s_q / head_num / d from the actual tensor shape (q is 4D).
    batch_size, s_q, head_num, d = q.shape
    q = q.view([-1, head_num, d]).contiguous()
    blocked_k = blocked_k.view([-1, d]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(d)

    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

    # Per-architecture tile size and software-pipelining depth.
    major, _ = get_device_capability()
    if major == 9:
        BLOCK_H = 64
        num_stages = 3
    elif major == 8:
        BLOCK_H = 16
        num_stages = 2
    elif major == 7 and vendor_name == "iluvatar":
        BLOCK_H = 32
        num_stages = 1
    else:
        error.backend_not_support(device)
    BLOCK_N = 32
    # NOTE(review): the grid's batch axis is batch_size (= b) while q and o are
    # flattened to b * s_q rows — rows past b are never computed, so this looks
    # correct only for decode (s_q == 1). Confirm with callers.
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        batch_size,
    )
    with torch_device_fn.device(device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # stride
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=num_stages,
        )

    return o.view([b, s_q, h_q, dv])