Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/get_scheduler_metadata.py: 0%
277 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils.device_info import get_device_capability, get_device_info
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def get_dtype_bytes(dtype):
    """Return the per-element storage size of a torch dtype in bytes.

    Args:
        dtype: a ``torch.dtype`` (floating point, integer, or bool).

    Returns:
        int: element size in bytes (e.g. 2 for fp16/bf16, 4 for fp32/int32).
    """
    if dtype is torch.bool:
        # torch.iinfo() rejects torch.bool; it is stored one byte per element.
        return 1
    # finfo/iinfo expose the bit width; use integer floor division instead of
    # float division + int() — bit widths are always byte multiples here.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits // 8
def tile_size_fwd_sm8x(
    sm86_or_89: bool,
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    paged_kv: bool = False,
    varlen_and_split: bool = False,
    softcap: bool = False,
    append_kv: bool = False,
):
    """Select the forward-pass tile configuration for SM8x GPUs.

    Returns a 5-tuple ``(kBlockM, kBlockN, kNWarps, kStages, Q_in_regs)``.
    ``kBlockM`` is always 128; the remaining values are chosen per head
    dimension bucket, shrinking ``kBlockN`` when the workload is varlen+split,
    local-attention, paged, or appending KV (those variants need more
    registers / shared memory elsewhere).
    """
    if element_size != 2:
        # fp8 path: a single fixed configuration.
        return 128, 64, 8, 2, False

    # fp16 / bf16 below; dispatch on the rounded head dimension.
    if headdim <= 64:
        if varlen_and_split:
            n_tile = 80
        elif is_local:
            n_tile = 96
        else:
            n_tile = 112
        return 128, n_tile, 4, 1, False

    if headdim <= 96:
        n_tile = 48 if (varlen_and_split or is_local) else 64
        return 128, n_tile, 4, 1, False

    if headdim <= 128:
        eight_warps = sm86_or_89 or varlen_and_split
        if not eight_warps:
            return 128, (48 if is_local else 64), 4, 1, False
        if varlen_and_split:
            n_tile = 96 if is_local else 112
        else:
            n_tile = 96 if is_local else 128
        # Q is kept in registers only on the 8-warp configuration.
        return 128, n_tile, 8, 1, True

    if headdim <= 192:
        shrink_n = append_kv or is_local or varlen_and_split or paged_kv
        stages = 1 if sm86_or_89 else 2
        return 128, (64 if shrink_n else 96), 8, stages, not shrink_n

    # headdim > 192: smallest tiles; SM86/SM89 have less shared memory.
    if append_kv:
        n_tile = 32 if sm86_or_89 else 48
    elif varlen_and_split or is_local:
        n_tile = 48 if sm86_or_89 else 64
    else:
        n_tile = 64 if sm86_or_89 else 96
    return 128, n_tile, 8, 1, sm86_or_89 and not append_kv
def get_optimal_block_mn(
    device,
    headdim,
    headdim_v,
    is_causal,
    is_local,
    has_softcap,
    element_size=2,
    paged_kv=False,
    varlen_and_split=False,
    append_kv=False,
):
    """Return the forward tile shape ``(kBlockM, kBlockN)`` for this GPU.

    Thin wrapper around :func:`tile_size_fwd_sm8x`; the warp / stage /
    register-residency components of the tile description are dropped.
    Note: ``device`` is currently unused — compute capability is queried
    via the global ``get_device_capability()``.
    """
    major, minor = get_device_capability()
    compute_cap = 10 * major + minor
    tile = tile_size_fwd_sm8x(
        sm86_or_89=compute_cap in (86, 89),
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        paged_kv=paged_kv,
        varlen_and_split=varlen_and_split,
        softcap=has_softcap,
        append_kv=append_kv,
    )
    return tile[0], tile[1]
def round_up_headdim(headdim: int) -> int:
    """Round ``headdim`` up to the nearest supported bucket, capped at 256."""
    for bucket in (64, 96, 128, 192, 256):
        if headdim <= bucket:
            return bucket
    # Anything above 256 is clamped to the largest bucket.
    return 256
def round_up_headdimv(headdim_v: int) -> int:
    """Round the value head dim up to the nearest bucket; above 256 -> 512."""
    for bucket in (64, 96, 128, 192, 256):
        if headdim_v <= bucket:
            return bucket
    # Unlike round_up_headdim, V dims above 256 round up to 512.
    return 512
def use_one_mma_wg(
    arch: int,
    headdim: int,
    seqlen_q: int,
    pack_gqa: bool,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    """Whether a single MMA warp group suffices (Hopper+, headdim 128 only).

    When pack-GQA is active the query length is effectively multiplied by the
    number of query heads sharing a KV head; one warp group is enough only
    when that effective length fits in 64 rows.
    """
    if arch < 90:
        return False
    if headdim != 128:
        return False
    multiplier = num_heads // num_heads_k if pack_gqa else 1
    return seqlen_q * multiplier <= 64
def get_num_splits(
    batch_size: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    d_rounded: int,
    dv_rounded: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    max_seqlen_k_new: int,
    arch: int,
    num_sm: int,
    is_causal: bool,
    is_local: bool,
    has_softcap: float,
    is_varlen: bool,
    has_page_table: bool,
    element_size: int = 2,  # fp16/bf16 = 2, fp8 = 1
    max_splits: int = 128,
    use_dynamic_split: bool = False,
) -> int:
    """Pick a static Split-K factor for the attention forward pass.

    Derives the kernel tile shape for the current architecture, converts the
    problem size into m/n block counts, and delegates the final choice to
    :func:`_vllm_num_splits_heuristic`.  When ``use_dynamic_split`` is set,
    the batch dimension is excluded (per-sequence splits are decided later
    on-device by the pass-2 kernel).
    """
    pagedkv_tma = False
    appending_kv = max_seqlen_k_new > 0

    if arch >= 90:
        # TODO: tile_size_fwd_sm90
        block_m, block_n = get_optimal_block_mn(
            device=0,
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            has_softcap=has_softcap,
            element_size=element_size,
            paged_kv=has_page_table and not pagedkv_tma,
            varlen_and_split=is_varlen,
            append_kv=appending_kv,
        )
    else:
        block_m, block_n, _, _, _ = tile_size_fwd_sm8x(
            sm86_or_89=arch in (86, 89),
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            paged_kv=has_page_table,
            varlen_and_split=is_varlen,
            softcap=(has_softcap > 0.0),
            append_kv=appending_kv,
        )

    # Pack-GQA folds query heads sharing one KV head into the sequence dim.
    packed_q_len = max_seqlen_q * (num_heads // num_heads_k)

    if is_local:
        # Local attention only ever loads a window of K.
        loaded_k_len = max(0, min(max_seqlen_k, block_m + max_seqlen_q))
    else:
        loaded_k_len = max_seqlen_k

    n_blocks = (loaded_k_len + block_n - 1) // block_n
    m_blocks = (packed_q_len + block_m - 1) // block_m

    # Bytes of K+V touched for one KV head; the heuristic compares this to L2.
    kv_head_bytes = max_seqlen_k * (headdim + headdim_v) * element_size

    effective_batch = 1 if use_dynamic_split else batch_size
    return _vllm_num_splits_heuristic(
        total_mblocks=effective_batch * num_heads_k * m_blocks,
        num_sm=num_sm,
        num_n_blocks=n_blocks,
        num_m_blocks=m_blocks,
        size_one_kv_head=kv_head_bytes,
        is_causal_or_local=is_causal or is_local,
        max_splits=max_splits,
    )
257def _vllm_num_splits_heuristic(
258 total_mblocks: int,
259 num_sm: int,
260 num_n_blocks: int,
261 num_m_blocks: int,
262 size_one_kv_head: int,
263 is_causal_or_local: bool,
264 max_splits: int,
265) -> int:
266 if total_mblocks >= 0.8 * num_sm:
267 size_l2 = 50 * 1024 * 1024
268 if (
269 size_one_kv_head > size_l2
270 and num_m_blocks >= num_sm * 2
271 and not is_causal_or_local
272 ):
273 return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
274 else:
275 return 1
277 if num_n_blocks <= 4:
278 return 1
280 max_splits = min(max_splits, num_sm, num_n_blocks)
282 max_efficiency = 0.0
283 efficiencies = []
285 for num_splits in range(1, max_splits + 1):
286 n_waves = float(total_mblocks * num_splits) / num_sm
287 eff = n_waves / math.ceil(n_waves)
288 if eff > max_efficiency:
289 max_efficiency = eff
290 efficiencies.append(eff)
292 for num_splits in range(1, max_splits + 1):
293 if efficiencies[num_splits - 1] >= 0.85 * max_efficiency:
294 return num_splits
296 return 1
@triton.jit
def _prepare_pass1_kernel(
    num_m_blocks_ptr,
    num_n_blocks_ptr,
    total_blocks_ptr,
    seqlen_k_ptr,
    cu_seqlens_q_ptr,
    cu_seqlens_k_ptr,
    cu_seqlens_k_new_ptr,
    seqused_q_ptr,
    seqused_k_ptr,
    leftpad_k_ptr,
    batch,
    qhead_per_khead,
    max_seqlen_q: tl.constexpr,
    max_seqlen_k_new: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
    # HAS_XXX is used to implement static branch in kernel
    HAS_CU_SEQLENS_Q: tl.constexpr,
    HAS_CU_SEQLENS_K: tl.constexpr,
    HAS_SEQUSED_Q: tl.constexpr,
    HAS_SEQUSED_K: tl.constexpr,
    HAS_LEFT_PAD: tl.constexpr,
    HAS_K_NEW: tl.constexpr,
    HAS_CU_SEQLENS_K_NEW: tl.constexpr,
):
    """Pass 1: per-sequence tile-block counts.

    For each batch element, derives the effective query/key lengths from
    whichever metadata tensors are present, converts them to m/n block counts
    (ceil-div by BLOCK_M / BLOCK_N), stores them, and atomically accumulates
    the grand total m*n block count into ``total_blocks_ptr[0]``.
    Each program handles BLOCK_SIZE_B consecutive batch indices.
    """
    pid = tl.program_id(0)
    b_start = pid * BLOCK_SIZE_B
    b_offs = b_start + tl.arange(0, BLOCK_SIZE_B)
    mask = b_offs < batch

    # Query length priority: seqused_q > cu_seqlens_q diffs > max_seqlen_q.
    if HAS_SEQUSED_Q:
        q_len = tl.load(seqused_q_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_Q:
        cur = tl.load(cu_seqlens_q_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_q_ptr + b_offs + 1, mask=mask, other=0)
        q_len = nxt - cur
    else:
        q_len = tl.full(
            [BLOCK_SIZE_B], max_seqlen_q, dtype=tl.int32
        )  # max_seqlen_q constexpr
    # Pack-GQA: query heads sharing a KV head are folded into the row dim.
    q_len = q_len * qhead_per_khead
    m_blocks = (q_len + BLOCK_M - 1) // BLOCK_M

    # Key length priority: seqused_k > cu_seqlens_k diffs > seqlen_k tensor.
    if HAS_SEQUSED_K:
        k_len = tl.load(seqused_k_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_K:
        cur = tl.load(cu_seqlens_k_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_k_ptr + b_offs + 1, mask=mask, other=0)
        k_len = nxt - cur
    else:
        k_len = tl.load(seqlen_k_ptr + b_offs, mask=mask, other=0)
    left = tl.load(leftpad_k_ptr + b_offs, mask=mask, other=0) if HAS_LEFT_PAD else 0

    # Tokens being appended to the KV cache extend the key length.
    if HAS_K_NEW:
        if HAS_CU_SEQLENS_K_NEW:
            cur_new = tl.load(cu_seqlens_k_new_ptr + b_offs, mask=mask, other=0)
            nxt_new = tl.load(cu_seqlens_k_new_ptr + b_offs + 1, mask=mask, other=0)
            k_len += nxt_new - cur_new
        else:
            k_len += max_seqlen_k_new
    # Left padding shortens the usable key range.
    k_len = k_len - left
    n_blocks = (k_len + BLOCK_N - 1) // BLOCK_N

    tl.store(num_m_blocks_ptr + b_offs, m_blocks, mask=mask)
    tl.store(num_n_blocks_ptr + b_offs, n_blocks, mask=mask)
    # Single scalar reduction across all programs via atomic add.
    total = m_blocks * n_blocks
    tl.atomic_add(total_blocks_ptr, tl.sum(total, axis=0))
@triton.jit
def _prepare_pass2_kernel(
    num_n_blocks_per_seq_ptr,
    num_splits_dynamic_ptr,
    total_blocks,
    num_batch,
    num_head,
    num_sm,
    num_splits_static,
    BLOCK_SIZE_B: tl.constexpr,
):
    """
    Triton Kernel: Pass 2
    - Calculates the per-sequence dynamic number of splits for the Split-K
      optimization, based on the total block count computed in Pass 1.
    - Each program handles BLOCK_SIZE_B consecutive batch indices and writes
      an int per sequence into ``num_splits_dynamic_ptr``, clamped to
      ``[1, num_splits_static]``.
    """
    pid = tl.program_id(axis=0)
    b_start = pid * BLOCK_SIZE_B
    b_offsets = b_start + tl.arange(0, BLOCK_SIZE_B)
    b_mask = b_offsets < num_batch

    # Target work per SM; the 1.1 factor over-provisions by 10% so splits
    # err toward fewer, larger chunks.
    blocks_per_sm_float = tl.ceil(total_blocks * 1.1 * num_head / num_sm)
    blocks_per_sm = blocks_per_sm_float.to(tl.int32)
    # Guard against a zero divisor when the workload is tiny.
    blocks_per_sm = tl.maximum(1, blocks_per_sm)

    num_n_blocks = tl.load(num_n_blocks_per_seq_ptr + b_offsets, mask=b_mask, other=0)
    # Ceil-div: how many chunks of per-SM work this sequence's K dim spans.
    num_splits_dynamic = (num_n_blocks + blocks_per_sm - 1) // blocks_per_sm
    # Clamp to the static upper bound and ensure at least one split.
    num_splits_dynamic = tl.minimum(num_splits_dynamic, num_splits_static)
    num_splits_dynamic = tl.maximum(1, num_splits_dynamic)

    tl.store(num_splits_dynamic_ptr + b_offsets, num_splits_dynamic, mask=b_mask)
def get_pack_gqa(
    arch: int,
    has_page_table: bool,
    pagedkv_tma: bool,
    num_splits: int,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    """Decide whether to pack grouped-query heads into the sequence dim.

    Packing is forced on pre-Hopper parts, when a page table is used without
    TMA, or when Split-K is active.  Otherwise packing is currently disabled:
    the Hopper-specific heuristic (tile_size_fwd_sm90 / should_pack_gqa) is
    not implemented yet, so both the num_heads == num_heads_k case and the
    true-GQA case fall through to False.
    """
    paged_without_tma = has_page_table and not pagedkv_tma
    if arch < 90 or paged_without_tma or num_splits > 1:
        return True
    return False
def get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Build the int32 scheduler-metadata tensor for the attention forward.

    Normalizes the causal / sliding-window flags, picks a Split-K factor
    (statically via :func:`get_num_splits` or per-sequence via the two Triton
    prepare kernels), and packs the result into a flat int32 tensor on
    ``seqused_k.device``.  Layout: an optional leading semaphore slot
    (initialized to the total block count) followed, when dynamic split is
    active, by ``batch_size`` per-sequence split counts.  Returns an empty
    tensor when neither part is needed.
    """
    device = seqused_k.device
    dtype = torch.int32

    # check parameters
    supported_dtypes = (torch.half, torch.bfloat16)
    assert (
        qkv_dtype in supported_dtypes
    ), "FlashAttention only supports fp16 and bf16 data type"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    # is_causal & window_size implementation:
    # window values covering the whole sequence are normalized to -1 (off).
    effective_is_causal = is_causal
    effective_window_left = window_size_left if window_size_left >= 0 else -1
    effective_window_right = window_size_right

    if effective_window_left >= max_seqlen_k - 1:
        effective_window_left = -1
    if effective_window_right >= max_seqlen_q - 1:
        effective_window_right = -1

    # Single-query decode with no window: causal masking is a no-op, so drop
    # it -- except (presumably a kernel-specialization constraint; TODO
    # confirm) for paged KV with 64 < headdim <= 128.
    if (
        max_seqlen_q == 1
        and effective_window_left == -1
        and effective_window_right == -1
    ):
        if (headdim <= 64 or headdim > 128) or page_size is None:
            effective_is_causal = False

    # Causal attention is expressed as a right window of 0.
    if effective_is_causal:
        effective_window_right = 0

    final_is_causal = effective_window_left < 0 and effective_window_right == 0
    final_is_local = (
        effective_window_left >= 0 or effective_window_right >= 0
    ) and not final_is_causal

    major, minor = get_device_capability()
    arch = major * 10 + minor
    # SMs reserved for other work (e.g. communication) are subtracted out.
    num_sm = get_device_info().sm_count - sm_margin

    # Downstream helpers take softcap as a float flag.
    softcap = 1.0 if has_softcap else 0.0

    element_size = get_dtype_bytes(qkv_dtype)

    has_page_table = page_size is not None

    # TODO implement get_pagedkv_tma function (Hopper+ only)
    pagedkv_tma = False

    blockM, blockN = get_optimal_block_mn(
        device=device,
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=final_is_causal,
        is_local=final_is_local,
        has_softcap=has_softcap,
        element_size=element_size,
    )

    # GQA
    # NOTE(review): this pack_gqa value (user-provided or heuristic) is
    # unconditionally overwritten after eff_num_splits is computed below,
    # so the heuristic result is currently discarded -- confirm intent.
    pack_gqa = (
        pack_gqa
        if pack_gqa is not None
        else get_pack_gqa(
            arch=arch,
            has_page_table=has_page_table,
            pagedkv_tma=pagedkv_tma,
            num_splits=num_splits,  # Note: user-provided num_splits, not eff_num_splits
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
    )
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    # TODO: implement use_one_mma_wg (Hopper+ only)

    # Per-sequence lengths; fall back to full tensors when varlen metadata is
    # absent.
    seqlen_q = (
        seqused_q
        if seqused_q is not None
        else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
    )
    seqlen_k = seqused_k
    seqlen_knew = (
        torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
        if max_seqlen_k_new > 0
        else None
    )

    # Outputs of the pass-1 kernel (per-sequence block counts + grand total).
    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)

    BLOCK_SIZE_B = 128
    grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

    _prepare_pass1_kernel[grid](
        num_m_blocks,
        num_n_blocks,
        total_blocks,
        seqlen_k,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        leftpad_k,
        batch_size,
        qhead_per_khead,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        BLOCK_M=blockM,
        BLOCK_N=blockN,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
        HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
        HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
        HAS_SEQUSED_Q=seqused_q is not None,
        HAS_SEQUSED_K=True,
        HAS_LEFT_PAD=leftpad_k is not None,
        HAS_K_NEW=seqlen_knew is not None,
        HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
    )

    # Device -> host sync: needed both for the metadata tensor and pass 2.
    total_blocks_val = total_blocks.item()

    # Dynamic (per-sequence) split only when the user didn't force a split
    # count; the 992 cap presumably bounds metadata size -- TODO confirm.
    use_dynamic_split = (num_splits <= 0) and (batch_size <= 992)

    if num_splits <= 0:
        # NOTE(review): element_size was already computed above; this
        # recomputation is redundant but harmless.
        element_size = get_dtype_bytes(qkv_dtype)
        is_fp16 = qkv_dtype == torch.float16
        is_bf16 = qkv_dtype == torch.bfloat16

        # NOTE(review): unreachable in practice -- the assert at the top
        # already restricts qkv_dtype to fp16/bf16.
        if not (is_fp16 or is_bf16):
            raise ValueError(
                f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"
            )

        d_rounded = round_up_headdim(headdim)
        dv_rounded = round_up_headdimv(headdim_v)

        eff_num_splits = get_num_splits(
            batch_size=batch_size,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            headdim=headdim,
            headdim_v=headdim_v,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_k_new=max_seqlen_k_new,
            arch=arch,
            num_sm=num_sm,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=softcap,
            is_varlen=True,
            has_page_table=has_page_table,
            element_size=element_size,
            use_dynamic_split=use_dynamic_split,
        )
    else:
        eff_num_splits = num_splits

    # Hard caps: at most 256 splits and no more splits than SMs.
    eff_num_splits = min(eff_num_splits, 256, num_sm)

    # NOTE(review): overwrites the pack_gqa chosen above (including an
    # explicit user-provided value) with "pack iff splitting" -- confirm.
    pack_gqa = eff_num_splits > 1
    if pack_gqa:
        qhead_per_khead = (num_heads + num_heads_k - 1) // num_heads_k
        num_head_k = num_heads_k
    else:
        qhead_per_khead = 1
        num_head_k = num_heads

    if use_dynamic_split:
        # Pass 2: per-sequence split counts, clamped by eff_num_splits.
        _prepare_pass2_kernel[grid](
            num_n_blocks,
            num_splits_dynamic,
            total_blocks=total_blocks_val,
            num_batch=batch_size,
            num_head=num_head_k,
            num_sm=num_sm,
            num_splits_static=eff_num_splits,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
        )
    else:
        num_splits_dynamic.fill_(eff_num_splits)

    final_num_splits = eff_num_splits

    is_varlen = True

    # NOTE(review): this arch-dependent semaphore computation is dead code --
    # scheduler_needs_semaphore is unconditionally reassigned below.
    if arch >= 90:
        scheduler_needs_semaphore = (
            (final_is_causal or final_is_local) and (final_num_splits == 1)
        ) or is_varlen
    else:
        scheduler_needs_semaphore = (final_is_causal and not is_varlen) or (
            is_varlen and final_num_splits > 1
        )

    # NOTE(review): both branches assign the same value; the if/else is
    # redundant as written.
    if use_dynamic_split:
        final_num_splits_for_sem_check = eff_num_splits
    else:
        final_num_splits_for_sem_check = eff_num_splits

    # Effective rule actually in force: Hopper+ always gets a semaphore,
    # older parts only when Split-K is active.
    scheduler_needs_semaphore = arch >= 90 or final_num_splits_for_sem_check > 1

    # Layout: [semaphore (1 int)?][num_splits_dynamic (batch_size ints)?]
    alloc_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size

    if alloc_size > 0:
        scheduler_metadata = torch.empty(alloc_size, dtype=torch.int32, device=device)
        offset = 0
        if scheduler_needs_semaphore:
            # Semaphore slot seeded with the total tile-block count.
            scheduler_metadata[offset] = total_blocks_val
            offset += 1

        if use_dynamic_split:
            scheduler_metadata[offset:] = num_splits_dynamic
        elif scheduler_needs_semaphore and not use_dynamic_split:
            # NOTE(review): intentional no-op branch (semaphore-only layout).
            pass
        return scheduler_metadata
    else:
        return torch.empty((0,), dtype=torch.int32, device=device)