# Coverage for src/flag_gems/ops/get_scheduler_metadata.py: 7% (332 statements)
# Extracted from a coverage.py v7.6.9 report, created at 2026-03-17 02:35 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils.device_info import get_device_capability, get_device_info
11logger = logging.getLogger(__name__)
def tile_size_fwd_sm8x(
    sm86_or_89: bool,
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    paged_kv: bool = False,
    varlen_and_split: bool = False,
    softcap: bool = False,
    append_kv: bool = False,
):
    """Select forward-pass tile configuration for SM8x GPUs.

    Returns a 5-tuple ``(kBlockM, kBlockN, kNWarps, kStages, Q_in_regs)``
    chosen from a fixed table keyed on head dimension, element size and the
    attention variant flags.  ``kBlockM`` is always 128 on this architecture.
    """
    if element_size != 2:
        # fp8 (or any non-16-bit) path uses one fixed configuration.
        return 128, 64, 8, 2, False

    # fp16 / bf16 below.
    if headdim <= 64:
        n_tile = 80 if varlen_and_split else (96 if is_local else 112)
        return 128, n_tile, 4, 1, False

    if headdim <= 96:
        n_tile = 48 if (varlen_and_split or is_local) else 64
        return 128, n_tile, 4, 1, False

    if headdim <= 128:
        # SM86/SM89 (or varlen+split) benefit from 8 warps and keeping Q
        # in registers.
        eight_warps = sm86_or_89 or varlen_and_split
        if eight_warps:
            if varlen_and_split:
                n_tile = 96 if is_local else 112
            else:
                n_tile = 96 if is_local else 128
            return 128, n_tile, 8, 1, True
        n_tile = 48 if is_local else 64
        return 128, n_tile, 4, 1, False

    if headdim <= 192:
        # Any of these variants forces a smaller N tile (and spills Q).
        shrink = append_kv or is_local or varlen_and_split or paged_kv
        stages = 1 if sm86_or_89 else 2
        return 128, (64 if shrink else 96), 8, stages, not shrink

    # headdim > 192
    if sm86_or_89:
        if append_kv:
            n_tile = 32
        else:
            n_tile = 48 if (varlen_and_split or is_local) else 64
    else:
        if append_kv:
            n_tile = 48
        else:
            n_tile = 64 if (varlen_and_split or is_local) else 96
    return 128, n_tile, 8, 1, sm86_or_89 and not append_kv
def tile_size_fwd_sm90(
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    v_colmajor: bool = False,
    paged_kv_non_TMA: bool = False,
    softcap: bool = False,
    use_one_mma_wg: bool = False,
):
    """Select forward-pass tile sizes for SM90 (Hopper) GPUs.

    Returns ``(kBlockM, kBlockN)`` from a fixed table keyed on head
    dimensions, element size and the attention variant flags.
    """
    if element_size == 2:
        # fp16 / bf16.
        if headdim <= 64:
            if headdim_v == 512:
                return 64, 64
            if headdim_v == 256:
                return 128, 112
            return 192, 128 if (is_causal or is_local) else 192
        if headdim <= 96:
            return 192, 128 if (is_local or paged_kv_non_TMA) else 144
        if headdim <= 128:
            # A single MMA warp-group halves the M tile.
            m_tile = 64 if use_one_mma_wg else 128
            n_tile = 128 if (is_causal or is_local or paged_kv_non_TMA) else 176
            return m_tile, n_tile
        if headdim <= 192:
            if paged_kv_non_TMA or is_local:
                return 128, 96
            return 128, 128 if headdim_v <= 128 else 112
        return 128, 64 if is_local else 80

    # fp8 path.
    if headdim <= 64:
        return 192, 160
    if headdim <= 96:
        return 192, 128
    if headdim <= 128:
        if paged_kv_non_TMA:
            return 128, 160
        return 128, 192 if (v_colmajor or (softcap and is_local)) else 224
    if headdim <= 192:
        return 128, 128 if ((paged_kv_non_TMA or softcap) and is_local) else 160
    return 128, 64 if is_local else 128
def get_optimal_block_mn(
    device,
    headdim,
    headdim_v,
    is_causal,
    is_local,
    has_softcap,
    element_size=2,
    paged_kv=False,
    pagedkv_tma: bool = False,
    varlen_and_split=False,
    append_kv=False,
):
    """Return ``(kBlockM, kBlockN)`` for the active GPU architecture.

    Dispatches to the SM90 tile table on Hopper-or-newer devices and to the
    SM8x table otherwise.  ``device`` is accepted for interface compatibility;
    the capability is queried from the active device.
    """
    major, minor = get_device_capability()
    arch = 10 * major + minor

    if arch >= 90:
        return tile_size_fwd_sm90(
            headdim=headdim,
            headdim_v=headdim_v,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=bool(paged_kv and (not pagedkv_tma)),
            softcap=has_softcap,
            use_one_mma_wg=False,
        )

    block_m, block_n, _warps, _stages, _q_in_regs = tile_size_fwd_sm8x(
        sm86_or_89=arch in (86, 89),
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        paged_kv=paged_kv,
        varlen_and_split=varlen_and_split,
        softcap=has_softcap,
        append_kv=append_kv,
    )
    return block_m, block_n
def round_up_headdim(headdim: int) -> int:
    """Round a query/key head dimension up to the nearest supported bucket.

    Supported buckets are 64, 96, 128, 192 and 256; values above 256 are
    clamped down to 256.
    """
    for bucket in (64, 96, 128, 192, 256):
        if headdim <= bucket:
            return bucket
    return 256
def round_up_headdimv(headdim_v: int) -> int:
    """Round a value head dimension up to the nearest supported bucket.

    Supported buckets are 64, 96, 128, 192 and 256; anything larger maps to
    512 (unlike :func:`round_up_headdim`, which caps at 256).
    """
    for bucket in (64, 96, 128, 192, 256):
        if headdim_v <= bucket:
            return bucket
    return 512
def get_pagedkv_tma(
    arch: int,
    page_size: int,
    has_page_table: bool,
    leftpad_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_new: int,
    num_heads: int,
    num_heads_k: int,
    d_rounded: int,
    dv_rounded: int,
    is_causal: bool,
    is_local: bool,
    element_size: int,
    softcap: bool,
):
    """Decide whether the paged-KV TMA load path can be used.

    TMA requires SM90+, a page table, no left padding, no appended KV, a
    page size that is a multiple of the N tile, and enough (GQA-packed)
    query rows to fill more than one M tile.
    """
    eligible = (
        arch >= 90
        and has_page_table
        and leftpad_k is None
        and max_seqlen_k_new <= 0
    )
    if not eligible:
        return False

    block_m, block_n = tile_size_fwd_sm90(
        headdim=d_rounded,
        headdim_v=dv_rounded,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        v_colmajor=False,
        paged_kv_non_TMA=False,
        softcap=softcap,
        use_one_mma_wg=False,
    )
    if page_size % block_n:
        return False

    packed_seqlen_q = max_seqlen_q * (num_heads // num_heads_k)
    return packed_seqlen_q > block_m
def use_one_mma_wg(
    arch: int,
    headdim: int,
    seqlen_q: int,
    pack_gqa: bool,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    """Return True when a single MMA warp-group suffices.

    Applies only on SM90+ with ``headdim == 128`` and when the query length
    (after optional GQA packing) fits within 64 rows.
    """
    if arch < 90:
        return False
    if headdim != 128:
        return False
    gqa_ratio = num_heads // num_heads_k if pack_gqa else 1
    return seqlen_q * gqa_ratio <= 64
def should_pack_gqa(
    varlen_q: bool,
    seqlen_q: int,
    qhead_per_khead: int,
    blockM: int,
) -> bool:
    """Heuristic: is packing GQA heads into the M dimension worthwhile?

    Always True for varlen queries.  Otherwise compares tile-fill
    efficiency (valid rows / rounded-up rows) with and without packing and
    packs only when the unpacked efficiency falls below 90% of the packed
    one.
    """
    if varlen_q:
        return True

    def _ceil_to(value: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= value.
        return -(-value // multiple) * multiple

    packed_len = seqlen_q * qhead_per_khead
    eff_unpacked = seqlen_q / _ceil_to(seqlen_q, blockM)
    eff_packed = packed_len / _ceil_to(packed_len, blockM)
    return eff_unpacked < 0.9 * eff_packed
def get_num_splits(
    batch_size: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    d_rounded: int,
    dv_rounded: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    max_seqlen_k_new: int,
    arch: int,
    num_sm: int,
    is_causal: bool,
    is_local: bool,
    has_softcap: float,
    is_varlen: bool,
    has_page_table: bool,
    pack_gqa: bool,
    window_size_left: int,
    window_size_right: int,
    element_size: int = 2,  # fp16/bf16 = 2, fp8 = 1
    max_splits: int = 128,
    use_dynamic_split: bool = False,
) -> int:
    """Compute the static number of KV splits for the forward kernel.

    Picks tile sizes for the target architecture, derives per-sequence
    m/n block counts, and feeds them to the vLLM split heuristic.
    With ``use_dynamic_split`` the batch contribution is ignored (one
    sequence stands in for the whole batch).
    """
    pagedkv_tma = False
    append_kv = max_seqlen_k_new > 0
    softcap_on = has_softcap > 0.0

    if arch >= 90:
        one_wg = use_one_mma_wg(
            arch=arch,
            headdim=headdim,
            seqlen_q=max_seqlen_q,
            pack_gqa=pack_gqa,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
        kBlockM, kBlockN = tile_size_fwd_sm90(
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=(has_page_table and not pagedkv_tma),
            softcap=softcap_on,
            use_one_mma_wg=one_wg,
        )
    else:
        kBlockM, kBlockN, _, _, _ = tile_size_fwd_sm8x(
            sm86_or_89=arch in (86, 89),
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            paged_kv=has_page_table,
            varlen_and_split=is_varlen,
            softcap=softcap_on,
            append_kv=append_kv,
        )

    packed_seqlen_q = max_seqlen_q * (num_heads // num_heads_k)

    # Local attention only ever touches a bounded window of K.
    if is_local:
        window_span = window_size_left + window_size_right + 1 + kBlockM
        seqlen_k_loaded = max(0, min(max_seqlen_k, window_span))
    else:
        seqlen_k_loaded = max_seqlen_k

    num_n_blocks = -(-seqlen_k_loaded // kBlockN)
    num_m_blocks = -(-packed_seqlen_q // kBlockM)
    size_one_kv_head = max_seqlen_k * (headdim + headdim_v) * element_size

    # Dynamic split: one representative sequence instead of the full batch.
    effective_batch = 1 if use_dynamic_split else batch_size

    return _vllm_num_splits_heuristic(
        total_mblocks=effective_batch * num_heads_k * num_m_blocks,
        num_sm=num_sm,
        num_n_blocks=num_n_blocks,
        num_m_blocks=num_m_blocks,
        size_one_kv_head=size_one_kv_head,
        is_causal_or_local=is_causal or is_local,
        max_splits=max_splits,
    )
389def _vllm_num_splits_heuristic(
390 total_mblocks: int,
391 num_sm: int,
392 num_n_blocks: int,
393 num_m_blocks: int,
394 size_one_kv_head: int,
395 is_causal_or_local: bool,
396 max_splits: int,
397) -> int:
398 if total_mblocks >= 0.8 * num_sm:
399 size_l2 = 50 * 1024 * 1024
400 if (
401 size_one_kv_head > size_l2
402 and num_m_blocks >= num_sm * 2
403 and not is_causal_or_local
404 ):
405 return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
406 else:
407 return 1
409 if num_n_blocks <= 4:
410 return 1
412 max_splits = min(max_splits, num_sm, num_n_blocks)
414 max_efficiency = 0.0
415 efficiencies = []
417 for num_splits in range(1, max_splits + 1):
418 n_waves = float(total_mblocks * num_splits) / num_sm
419 eff = n_waves / math.ceil(n_waves)
420 if eff > max_efficiency:
421 max_efficiency = eff
422 efficiencies.append(eff)
424 for num_splits in range(1, max_splits + 1):
425 if efficiencies[num_splits - 1] >= 0.85 * max_efficiency:
426 return num_splits
428 return 1
@triton.jit
def _prepare_pass1_kernel(
    num_m_blocks_ptr,  # out: per-sequence M-block counts, int32[batch]
    num_n_blocks_ptr,  # out: per-sequence N-block counts, int32[batch]
    total_blocks_ptr,  # out: scalar accumulator for sum(m_blocks * n_blocks)
    seqlen_k_ptr,  # in: per-sequence K lengths (used when no seqused_k/cu_seqlens_k)
    cu_seqlens_q_ptr,
    cu_seqlens_k_ptr,
    cu_seqlens_k_new_ptr,
    seqused_q_ptr,
    seqused_k_ptr,
    leftpad_k_ptr,
    batch,
    qhead_per_khead,
    max_seqlen_q: tl.constexpr,
    max_seqlen_k_new: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
    # HAS_XXX is used to implement static branch in kernel
    HAS_CU_SEQLENS_Q: tl.constexpr,
    HAS_CU_SEQLENS_K: tl.constexpr,
    HAS_SEQUSED_Q: tl.constexpr,
    HAS_SEQUSED_K: tl.constexpr,
    HAS_LEFT_PAD: tl.constexpr,
    HAS_K_NEW: tl.constexpr,
    HAS_CU_SEQLENS_K_NEW: tl.constexpr,
):
    # Pass 1: for each batch element compute the number of M and N tiles,
    # store them, and atomically accumulate the grand total tile count.
    pid = tl.program_id(0)
    b_start = pid * BLOCK_SIZE_B
    b_offs = b_start + tl.arange(0, BLOCK_SIZE_B)
    mask = b_offs < batch

    # Query length per sequence. Priority: seqused_q > cu_seqlens_q
    # differences > the constexpr max_seqlen_q fallback.
    if HAS_SEQUSED_Q:
        q_len = tl.load(seqused_q_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_Q:
        cur = tl.load(cu_seqlens_q_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_q_ptr + b_offs + 1, mask=mask, other=0)
        q_len = nxt - cur
    else:
        q_len = tl.full(
            [BLOCK_SIZE_B], max_seqlen_q, dtype=tl.int32
        )  # max_seqlen_q constexpr
    # GQA packing multiplies the effective number of query rows.
    q_len = q_len * qhead_per_khead
    m_blocks = (q_len + BLOCK_M - 1) // BLOCK_M

    # Key length per sequence, with the same priority scheme.
    if HAS_SEQUSED_K:
        k_len = tl.load(seqused_k_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_K:
        cur = tl.load(cu_seqlens_k_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_k_ptr + b_offs + 1, mask=mask, other=0)
        k_len = nxt - cur
    else:
        k_len = tl.load(seqlen_k_ptr + b_offs, mask=mask, other=0)

    # Left padding is subtracted from the usable K length below.
    left = tl.load(leftpad_k_ptr + b_offs, mask=mask, other=0) if HAS_LEFT_PAD else 0

    # Appended (new) K tokens extend the K length.
    if HAS_K_NEW:
        if HAS_CU_SEQLENS_K_NEW:
            cur_new = tl.load(cu_seqlens_k_new_ptr + b_offs, mask=mask, other=0)
            nxt_new = tl.load(cu_seqlens_k_new_ptr + b_offs + 1, mask=mask, other=0)
            k_len += nxt_new - cur_new
        else:
            k_len += max_seqlen_k_new
    k_len = k_len - left
    n_blocks = (k_len + BLOCK_N - 1) // BLOCK_N

    tl.store(num_m_blocks_ptr + b_offs, m_blocks, mask=mask)
    tl.store(num_n_blocks_ptr + b_offs, n_blocks, mask=mask)
    # Masked lanes contribute 0 (m_blocks was loaded with other=0 paths or
    # masked stores above); reduce within the program, then one atomic.
    total = m_blocks * n_blocks
    tl.atomic_add(total_blocks_ptr, tl.sum(total, axis=0))
@triton.jit
def _prepare_pass2_kernel(
    num_n_blocks_per_seq_ptr,  # in: per-sequence N-block counts from pass 1
    num_splits_dynamic_ptr,  # out: per-sequence dynamic split counts
    total_blocks,  # scalar: grand total of M*N blocks across the batch
    num_batch,
    num_head,
    num_sm,
    num_splits_static,  # upper bound from the static heuristic
    BLOCK_SIZE_B: tl.constexpr,
):
    # Pass 2: derive a per-sequence dynamic split count so that each SM
    # processes roughly the same number of blocks (with a 1.1x slack
    # factor), clamped to [1, num_splits_static].
    pid = tl.program_id(axis=0)
    b_start = pid * BLOCK_SIZE_B
    b_offsets = b_start + tl.arange(0, BLOCK_SIZE_B)
    b_mask = b_offsets < num_batch

    # Target blocks per SM; ceil then clamp to at least 1 to avoid a
    # zero divisor below.
    blocks_per_sm_float = tl.ceil(total_blocks * 1.1 * num_head / num_sm)
    blocks_per_sm = blocks_per_sm_float.to(tl.int32)
    blocks_per_sm = tl.maximum(1, blocks_per_sm)

    num_n_blocks = tl.load(num_n_blocks_per_seq_ptr + b_offsets, mask=b_mask, other=0)
    # Ceil-divide each sequence's N blocks by the per-SM budget.
    num_splits_dynamic = (num_n_blocks + blocks_per_sm - 1) // blocks_per_sm
    # Clamp into [1, num_splits_static].
    num_splits_dynamic = tl.minimum(num_splits_dynamic, num_splits_static)
    num_splits_dynamic = tl.maximum(1, num_splits_dynamic)

    tl.store(num_splits_dynamic_ptr + b_offsets, num_splits_dynamic, mask=b_mask)
def get_pack_gqa(
    arch: int,
    has_page_table: bool,
    pagedkv_tma: bool,
    num_splits: int,
    num_heads: int,
    num_heads_k: int,
    # SM90-specific params for heuristic
    varlen_q: bool,
    seqlen_q: int,
    d_rounded: int,
    dv_rounded: int,
    is_causal: bool,
    is_local: bool,
    element_size: int,
    softcap: bool,
) -> bool:
    """Decide whether GQA heads should be packed into the M dimension.

    Pre-SM90 devices, non-TMA paged KV, and split attention always pack.
    Without grouped heads there is nothing to pack.  Otherwise defer to the
    tile-fill heuristic in :func:`should_pack_gqa`.
    """
    paged_non_tma = has_page_table and not pagedkv_tma
    if arch < 90 or paged_non_tma or num_splits > 1:
        return True
    if num_heads == num_heads_k:
        return False

    block_m, _ = tile_size_fwd_sm90(
        headdim=d_rounded,
        headdim_v=dv_rounded,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        v_colmajor=False,
        paged_kv_non_TMA=paged_non_tma,
        softcap=softcap,
        use_one_mma_wg=False,
    )
    gqa_ratio = num_heads // num_heads_k
    return should_pack_gqa(varlen_q, seqlen_q, gqa_ratio, block_m)
def get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Build the FlashAttention scheduler-metadata tensor.

    Normalizes the causal/local window flags, picks tile sizes for the
    device, resolves the static split count and GQA packing, then runs two
    Triton passes to compute per-sequence block counts and (optionally)
    dynamic per-sequence split counts.

    Returns an int32 tensor laid out as
    ``[semaphore?] + [num_splits_dynamic[batch_size]?]`` — the semaphore
    slot (zero-initialized) is present when ``scheduler_needs_semaphore``;
    the per-sequence split counts are present when dynamic splitting is
    enabled.  May be empty when neither is needed.
    """
    device = seqused_k.device
    dtype = torch.int32

    # check parameters
    supported_dtypes = (torch.half, torch.bfloat16)
    assert (
        qkv_dtype in supported_dtypes
    ), "FlashAttention only supports fp16 and bf16 data type"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    # is_causal & window_size implementation: normalize window bounds and
    # fold causal masking into window_size_right.
    effective_is_causal = is_causal
    effective_window_left = window_size_left if window_size_left >= 0 else -1
    effective_window_right = window_size_right

    # A window covering the whole sequence is equivalent to no window.
    if effective_window_left >= max_seqlen_k - 1:
        effective_window_left = -1
    if effective_window_right >= max_seqlen_q - 1:
        effective_window_right = -1

    # Single-query decode with no window: causal masking is a no-op
    # (except for mid-range headdims with paged KV, where it is kept).
    if (
        max_seqlen_q == 1
        and effective_window_left == -1
        and effective_window_right == -1
    ):
        if (headdim <= 64 or headdim > 128) or page_size is None:
            effective_is_causal = False

    if effective_is_causal:
        effective_window_right = 0

    # Causal == right window of 0 with unbounded left; local == any other
    # finite window.
    final_is_causal = effective_window_left < 0 and effective_window_right == 0
    final_is_local = (
        effective_window_left >= 0 or effective_window_right >= 0
    ) and not final_is_causal

    major, minor = get_device_capability()
    arch = major * 10 + minor
    num_sm = get_device_info().sm_count - sm_margin

    softcap = 1.0 if has_softcap else 0.0

    element_size = qkv_dtype.itemsize

    has_page_table = page_size is not None

    d_rounded = round_up_headdim(headdim)
    dv_rounded = round_up_headdimv(headdim_v)

    pagedkv_tma = get_pagedkv_tma(
        arch=arch,
        page_size=page_size if page_size is not None else 1,
        has_page_table=has_page_table,
        leftpad_k=leftpad_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        num_heads=num_heads,
        num_heads_k=num_heads_k,
        d_rounded=d_rounded,
        dv_rounded=dv_rounded,
        is_causal=final_is_causal,
        is_local=final_is_local,
        element_size=element_size,
        softcap=has_softcap,
    )

    # NOTE(review): these blockM/blockN values are recomputed below once
    # the split count is known; this first call appears redundant.
    blockM, blockN = get_optimal_block_mn(
        device=device,
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=final_is_causal,
        is_local=final_is_local,
        has_softcap=has_softcap,
        element_size=element_size,
        paged_kv=has_page_table,
        pagedkv_tma=pagedkv_tma,
    )

    # GQA
    varlen_q_flag = cu_seqlens_q is not None or seqused_q is not None
    pack_gqa = (
        pack_gqa
        if pack_gqa is not None
        else get_pack_gqa(
            arch=arch,
            has_page_table=has_page_table,
            pagedkv_tma=pagedkv_tma,
            num_splits=num_splits,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            varlen_q=varlen_q_flag,
            seqlen_q=max_seqlen_q,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            is_causal=final_is_causal,
            is_local=final_is_local,
            element_size=element_size,
            softcap=has_softcap,
        )
    )
    # num_heads % num_heads_k == 0 was asserted above, so the ceil-divide
    # is an exact quotient here.
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    # Per-sequence lengths; seqlen_q defaults to a constant-filled tensor.
    seqlen_q = (
        seqused_q
        if seqused_q is not None
        else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
    )
    seqlen_k = seqused_k
    seqlen_knew = (
        torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
        if max_seqlen_k_new > 0
        else None
    )

    # NOTE(review): these four tensors are allocated again just before the
    # pass-1 launch below; this first allocation appears redundant.
    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)

    BLOCK_SIZE_B = 128
    grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

    # NOTE(review): read before pass 1 has run (always 0 here); it is
    # re-read after the kernel launch below, so this looks redundant.
    total_blocks_val = total_blocks.item()

    # dynamic split depends ONLY on batch_size, regardless of num_splits_static
    use_dynamic_split = batch_size <= 992

    if num_splits <= 0:
        # Split count not forced by the caller: run the static heuristic.
        element_size = qkv_dtype.itemsize
        is_fp16 = qkv_dtype == torch.float16
        is_bf16 = qkv_dtype == torch.bfloat16

        # NOTE(review): unreachable — the dtype assert at the top already
        # rejects anything other than fp16/bf16.
        if not (is_fp16 or is_bf16):
            raise ValueError(
                f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"
            )

        # NOTE(review): self-assignments, no effect.
        d_rounded = d_rounded
        dv_rounded = dv_rounded

        eff_num_splits = get_num_splits(
            batch_size=batch_size,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            headdim=headdim,
            headdim_v=headdim_v,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_k_new=max_seqlen_k_new,
            arch=arch,
            num_sm=num_sm,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=softcap,
            is_varlen=True,
            has_page_table=has_page_table,
            pack_gqa=pack_gqa,
            window_size_left=effective_window_left,
            window_size_right=effective_window_right,
            element_size=element_size,
            use_dynamic_split=use_dynamic_split,
        )
    else:
        eff_num_splits = num_splits

    # Cap by hardware: never more splits than 256 or than SMs.
    eff_num_splits = min(eff_num_splits, 256, num_sm)

    # Always enable PackGQA for Split
    pack_gqa = True if eff_num_splits > 1 else pack_gqa

    # Recompute qhead_per_khead/num_head_k for the kernels
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    is_varlen = True
    # Recompute tile sizes now that the split decision and final pack_gqa
    # are known (they influence the tables).
    if arch >= 90:
        uomw = use_one_mma_wg(
            arch=arch,
            headdim=headdim,
            seqlen_q=max_seqlen_q,
            pack_gqa=pack_gqa,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
        blockM, blockN = tile_size_fwd_sm90(
            headdim=round_up_headdim(headdim),
            headdim_v=round_up_headdimv(headdim_v),
            is_causal=final_is_causal,
            is_local=final_is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=(has_page_table and not pagedkv_tma),
            softcap=has_softcap,
            use_one_mma_wg=uomw,
        )
    else:
        blockM, blockN = get_optimal_block_mn(
            device=device,
            headdim=headdim,
            headdim_v=headdim_v,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=has_softcap,
            element_size=element_size,
            paged_kv=has_page_table,
            pagedkv_tma=pagedkv_tma,
            varlen_and_split=is_varlen and (eff_num_splits > 1),
            append_kv=(max_seqlen_k_new > 0),
        )

    # Working buffers for the two preparation passes.
    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)

    # Pass 1: per-sequence m/n block counts + global block total.
    _prepare_pass1_kernel[grid](
        num_m_blocks,
        num_n_blocks,
        total_blocks,
        seqlen_k,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        leftpad_k,
        batch_size,
        qhead_per_khead,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        BLOCK_M=blockM,
        BLOCK_N=blockN,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
        HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
        HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
        HAS_SEQUSED_Q=seqused_q is not None,
        HAS_SEQUSED_K=True,
        HAS_LEFT_PAD=leftpad_k is not None,
        HAS_K_NEW=seqlen_knew is not None,
        HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
    )

    # Device sync point: pull the block total back to the host.
    total_blocks_val = total_blocks.item()

    if use_dynamic_split:
        # Pass 2: per-sequence dynamic split counts balanced over SMs.
        _prepare_pass2_kernel[grid](
            num_n_blocks,
            num_splits_dynamic,
            total_blocks=total_blocks_val,
            num_batch=batch_size,
            num_head=num_head_k,
            num_sm=num_sm,
            num_splits_static=eff_num_splits,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
        )
    else:
        num_splits_dynamic.fill_(eff_num_splits)

    final_num_splits = eff_num_splits

    is_varlen = True

    # NOTE(review): the arch-specific semaphore decision below is computed
    # and then unconditionally overwritten a few lines later — dead code?
    if arch >= 90:
        scheduler_needs_semaphore = (
            (final_is_causal or final_is_local) and (final_num_splits == 1)
        ) or is_varlen
    else:
        scheduler_needs_semaphore = (final_is_causal and not is_varlen) or (
            is_varlen and final_num_splits > 1
        )

    # NOTE(review): both branches assign the same value.
    if use_dynamic_split:
        final_num_splits_for_sem_check = eff_num_splits
    else:
        final_num_splits_for_sem_check = eff_num_splits

    # This overrides the arch-specific computation above.
    scheduler_needs_semaphore = arch >= 90 or final_num_splits_for_sem_check > 1

    # Layout: [semaphore?] + [num_splits_dynamic per sequence?].
    alloc_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size

    if alloc_size > 0:
        scheduler_metadata = torch.empty(alloc_size, dtype=torch.int32, device=device)
        offset = 0
        if scheduler_needs_semaphore:
            # Semaphore slot starts at zero.
            scheduler_metadata[offset] = 0
            offset += 1

        if use_dynamic_split:
            scheduler_metadata[offset:] = num_splits_dynamic
        elif scheduler_needs_semaphore and not use_dynamic_split:
            pass
        return scheduler_metadata
    else:
        return torch.empty((0,), dtype=torch.int32, device=device)