Coverage for src/flag_gems/ops/get_scheduler_metadata.py: 7%

332 statements  
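
"""Scheduler-metadata helpers for the flash-attention forward pass.

The functions below pick forward-pass tile sizes for SM8x and SM90 GPUs,
estimate how many KV splits to use (a static estimate plus optional per-batch
dynamic splits), and assemble the int32 scheduler metadata tensor returned by
get_scheduler_metadata: an optional tile-scheduler semaphore slot followed by
one dynamic split count per batch element.
"""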


import logging
import math
from typing import Optional

import torch
import triton
import triton.language as tl

from flag_gems.utils.device_info import get_device_capability, get_device_info

logger = logging.getLogger(__name__)

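# tile_size_fwd_sm8x: forward-pass tile configuration for SM80/SM86/SM89 GPUs.
# Returns (kBlockM, kBlockN, kNWarps, kStages, Q_in_regs), chosen from the head
# dimension, element size, and the causal/local/paged-KV/varlen-split flags.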

def tile_size_fwd_sm8x(
    sm86_or_89: bool,
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    paged_kv: bool = False,
    varlen_and_split: bool = False,
    softcap: bool = False,
    append_kv: bool = False,
):
    if element_size == 2:  # fp16/bf16
        if headdim <= 64:
            kBlockM = 128
            kBlockN = 80 if varlen_and_split else (96 if is_local else 112)
            kNWarps = 4
            kStages = 1
            Q_in_regs = False

        elif headdim <= 96:
            kBlockM = 128
            kBlockN = 48 if (varlen_and_split or is_local) else 64
            kNWarps = 4
            kStages = 1
            Q_in_regs = False

        elif headdim <= 128:
            use_8_warps = sm86_or_89 or varlen_and_split
            kBlockM = 128
            if use_8_warps:
                kBlockN = (
                    (96 if is_local else 112)
                    if varlen_and_split
                    else (96 if is_local else 128)
                )
            else:
                kBlockN = 48 if is_local else 64
            kNWarps = 8 if use_8_warps else 4
            kStages = 1
            Q_in_regs = use_8_warps

        elif headdim <= 192:
            kBlockN_64 = append_kv or is_local or varlen_and_split or paged_kv
            kBlockM = 128
            kBlockN = 64 if kBlockN_64 else 96
            kNWarps = 8
            kStages = 1 if sm86_or_89 else 2
            Q_in_regs = not kBlockN_64

        else:  # headdim > 192
            kBlockM = 128
            if sm86_or_89:
                if append_kv:
                    kBlockN = 32
                elif varlen_and_split or is_local:
                    kBlockN = 48
                else:
                    kBlockN = 64
            else:
                if append_kv:
                    kBlockN = 48
                elif varlen_and_split or is_local:
                    kBlockN = 64
                else:
                    kBlockN = 96
            kNWarps = 8
            kStages = 1
            Q_in_regs = sm86_or_89 and not append_kv
    else:
        kBlockM = 128
        kBlockN = 64
        kNWarps = 8
        kStages = 2
        Q_in_regs = False

    return kBlockM, kBlockN, kNWarps, kStages, Q_in_regs

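# tile_size_fwd_sm90: (kBlockM, kBlockN) tile sizes for SM90 (Hopper), keyed on
# head dimension, element size (2 for fp16/bf16, 1 for fp8), and the
# causal/local/paged-KV/softcap flags.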

def tile_size_fwd_sm90(
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    v_colmajor: bool = False,
    paged_kv_non_TMA: bool = False,
    softcap: bool = False,
    use_one_mma_wg: bool = False,
):
    if element_size == 2:
        if headdim <= 64:
            if headdim_v == 512:
                return 64, 64
            elif headdim_v == 256:
                return 128, 112
            else:
                use_blockN_128 = is_causal or is_local
                return 192, (128 if use_blockN_128 else 192)
        elif headdim <= 96:
            return 192, (128 if (is_local or paged_kv_non_TMA) else 144)
        elif headdim <= 128:
            if use_one_mma_wg:
                return 64, (128 if (is_causal or is_local or paged_kv_non_TMA) else 176)
            else:
                return 128, (
                    128 if (is_causal or is_local or paged_kv_non_TMA) else 176
                )
        elif headdim <= 192:
            return 128, (
                96
                if (paged_kv_non_TMA or is_local)
                else (128 if headdim_v <= 128 else 112)
            )
        else:
            return 128, (64 if is_local else 80)
    else:
        if headdim <= 64:
            return 192, 160
        elif headdim <= 96:
            return 192, 128
        elif headdim <= 128:
            return 128, (
                160
                if paged_kv_non_TMA
                else (192 if (v_colmajor or (softcap and is_local)) else 224)
            )
        elif headdim <= 192:
            return 128, (128 if ((paged_kv_non_TMA or softcap) and is_local) else 160)
        else:
            return 128, (64 if is_local else 128)


def get_optimal_block_mn(
    device,
    headdim,
    headdim_v,
    is_causal,
    is_local,
    has_softcap,
    element_size=2,
    paged_kv=False,
    pagedkv_tma: bool = False,
    varlen_and_split=False,
    append_kv=False,
):
    major, minor = get_device_capability()
    arch = major * 10 + minor

    if arch >= 90:
        paged_kv_non_TMA = bool(paged_kv and (not pagedkv_tma))
        kBlockM, kBlockN = tile_size_fwd_sm90(
            headdim=headdim,
            headdim_v=headdim_v,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=paged_kv_non_TMA,
            softcap=has_softcap,
            use_one_mma_wg=False,
        )
        return kBlockM, kBlockN
    else:
        kBlockM, kBlockN, kNWarps, kStages, Q_in_regs = tile_size_fwd_sm8x(
            sm86_or_89=(arch == 86 or arch == 89),
            headdim=headdim,
            headdim_v=headdim_v,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            paged_kv=paged_kv,
            varlen_and_split=varlen_and_split,
            softcap=has_softcap,
            append_kv=append_kv,
        )
        return kBlockM, kBlockN


def round_up_headdim(headdim: int) -> int:
    if headdim <= 64:
        return 64
    if headdim <= 96:
        return 96
    if headdim <= 128:
        return 128
    if headdim <= 192:
        return 192
    if headdim <= 256:
        return 256
    return 256


def round_up_headdimv(headdim_v: int) -> int:
    if headdim_v <= 64:
        return 64
    if headdim_v <= 96:
        return 96
    if headdim_v <= 128:
        return 128
    if headdim_v <= 192:
        return 192
    if headdim_v <= 256:
        return 256
    return 512

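# get_pagedkv_tma: paged KV may use TMA only on SM90+ with a page table, no left
# padding and no appended KV, when the page size is a multiple of kBlockN and the
# (GQA-packed) query length spans more than one kBlockM tile.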

def get_pagedkv_tma(
    arch: int,
    page_size: int,
    has_page_table: bool,
    leftpad_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_new: int,
    num_heads: int,
    num_heads_k: int,
    d_rounded: int,
    dv_rounded: int,
    is_causal: bool,
    is_local: bool,
    element_size: int,
    softcap: bool,
):
    if (
        arch < 90
        or (not has_page_table)
        or (leftpad_k is not None)
        or (max_seqlen_k_new > 0)
    ):
        return False
    kBlockM, kBlockN = tile_size_fwd_sm90(
        headdim=d_rounded,
        headdim_v=dv_rounded,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        v_colmajor=False,
        paged_kv_non_TMA=False,
        softcap=softcap,
        use_one_mma_wg=False,
    )
    if page_size % kBlockN != 0:
        return False
    seqlen_q_packgqa = max_seqlen_q * (num_heads // num_heads_k)
    return seqlen_q_packgqa > kBlockM

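# use_one_mma_wg: on SM90 with headdim == 128, use a single MMA warpgroup when
# the (GQA-packed) query length is at most 64 rows.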

def use_one_mma_wg(
    arch: int,
    headdim: int,
    seqlen_q: int,
    pack_gqa: bool,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    if arch < 90 or headdim != 128:
        return False

    qhead_per_khead = 1 if not pack_gqa else num_heads // num_heads_k
    effective_seqlen_q = seqlen_q * qhead_per_khead

    return effective_seqlen_q <= 64

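# should_pack_gqa: packing query heads into the M dimension is always used for
# varlen batches; otherwise it is chosen only when the packed layout fills
# blockM-sized tiles noticeably better (nopack_eff < 0.9 * pack_eff). Example
# with illustrative numbers: seqlen_q=1, qhead_per_khead=4, blockM=128 gives
# nopack_eff = 1/128 and pack_eff = 4/128, so packing wins.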

def should_pack_gqa(
    varlen_q: bool,
    seqlen_q: int,
    qhead_per_khead: int,
    blockM: int,
) -> bool:
    if varlen_q:
        return True

    def round_up(a: int, b: int) -> int:
        return (a + b - 1) // b * b

    nopack_eff = float(seqlen_q) / float(round_up(seqlen_q, blockM))
    pack_eff = float(seqlen_q * qhead_per_khead) / float(
        round_up(seqlen_q * qhead_per_khead, blockM)
    )
    return nopack_eff < 0.9 * pack_eff

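# get_num_splits: static split-count heuristic. It derives tile sizes for the
# current architecture, counts M/N blocks from the (GQA-packed) sequence
# lengths, and delegates to _vllm_num_splits_heuristic. With use_dynamic_split
# the batch dimension is collapsed to one element, so the result acts as the
# static cap that the per-batch dynamic splits are clamped against.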

def get_num_splits(
    batch_size: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    d_rounded: int,
    dv_rounded: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    max_seqlen_k_new: int,
    arch: int,
    num_sm: int,
    is_causal: bool,
    is_local: bool,
    has_softcap: float,
    is_varlen: bool,
    has_page_table: bool,
    pack_gqa: bool,
    window_size_left: int,
    window_size_right: int,
    element_size: int = 2,  # fp16/bf16 = 2, fp8 = 1
    max_splits: int = 128,
    use_dynamic_split: bool = False,
) -> int:
    pagedkv_tma = False
    append_kv = max_seqlen_k_new > 0

    if arch >= 90:
        uomw = use_one_mma_wg(
            arch=arch,
            headdim=headdim,
            seqlen_q=max_seqlen_q,
            pack_gqa=pack_gqa,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
        kBlockM, kBlockN = tile_size_fwd_sm90(
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=(has_page_table and not pagedkv_tma),
            softcap=(has_softcap > 0.0),
            use_one_mma_wg=uomw,
        )
    else:
        sm86_or_89 = arch == 86 or arch == 89
        kBlockM, kBlockN, _, _, _ = tile_size_fwd_sm8x(
            sm86_or_89=sm86_or_89,
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            paged_kv=has_page_table,
            varlen_and_split=is_varlen,
            softcap=(has_softcap > 0.0),
            append_kv=append_kv,
        )

    seqlen_q_packgqa = max_seqlen_q * (num_heads // num_heads_k)

    if is_local:
        seqlen_k_loaded = max(
            0,
            min(max_seqlen_k, window_size_left + window_size_right + 1 + kBlockM),
        )
    else:
        seqlen_k_loaded = max_seqlen_k

    num_n_blocks = (seqlen_k_loaded + kBlockN - 1) // kBlockN
    num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) // kBlockM

    size_one_kv_head = max_seqlen_k * (headdim + headdim_v) * element_size

    effective_batch = 1 if use_dynamic_split else batch_size
    total_mblocks = effective_batch * num_heads_k * num_m_blocks

    return _vllm_num_splits_heuristic(
        total_mblocks=total_mblocks,
        num_sm=num_sm,
        num_n_blocks=num_n_blocks,
        num_m_blocks=num_m_blocks,
        size_one_kv_head=size_one_kv_head,
        is_causal_or_local=is_causal or is_local,
        max_splits=max_splits,
    )

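# _vllm_num_splits_heuristic: choose a split count that keeps the SMs busy. If
# the M-blocks already cover ~80% of the SMs, split only when a single KV head
# overflows a nominal 50 MiB L2 (and there are plenty of M-blocks and no
# causal/local masking); otherwise return the smallest split count whose wave
# efficiency (n_waves / ceil(n_waves)) is within 85% of the best achievable.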

def _vllm_num_splits_heuristic(
    total_mblocks: int,
    num_sm: int,
    num_n_blocks: int,
    num_m_blocks: int,
    size_one_kv_head: int,
    is_causal_or_local: bool,
    max_splits: int,
) -> int:
    if total_mblocks >= 0.8 * num_sm:
        size_l2 = 50 * 1024 * 1024
        if (
            size_one_kv_head > size_l2
            and num_m_blocks >= num_sm * 2
            and not is_causal_or_local
        ):
            return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
        else:
            return 1

    if num_n_blocks <= 4:
        return 1

    max_splits = min(max_splits, num_sm, num_n_blocks)

    max_efficiency = 0.0
    efficiencies = []

    for num_splits in range(1, max_splits + 1):
        n_waves = float(total_mblocks * num_splits) / num_sm
        eff = n_waves / math.ceil(n_waves)
        if eff > max_efficiency:
            max_efficiency = eff
        efficiencies.append(eff)

    for num_splits in range(1, max_splits + 1):
        if efficiencies[num_splits - 1] >= 0.85 * max_efficiency:
            return num_splits

    return 1

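# Pass 1 (Triton): per batch element, compute the number of M-blocks from the
# query length (scaled by qhead_per_khead when GQA is packed) and the number of
# N-blocks from the key length (plus appended KV, minus left padding), store
# both, and atomically accumulate the grand total of M*N blocks.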

@triton.jit
def _prepare_pass1_kernel(
    num_m_blocks_ptr,
    num_n_blocks_ptr,
    total_blocks_ptr,
    seqlen_k_ptr,
    cu_seqlens_q_ptr,
    cu_seqlens_k_ptr,
    cu_seqlens_k_new_ptr,
    seqused_q_ptr,
    seqused_k_ptr,
    leftpad_k_ptr,
    batch,
    qhead_per_khead,
    max_seqlen_q: tl.constexpr,
    max_seqlen_k_new: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
    # HAS_XXX flags implement static branches in the kernel
    HAS_CU_SEQLENS_Q: tl.constexpr,
    HAS_CU_SEQLENS_K: tl.constexpr,
    HAS_SEQUSED_Q: tl.constexpr,
    HAS_SEQUSED_K: tl.constexpr,
    HAS_LEFT_PAD: tl.constexpr,
    HAS_K_NEW: tl.constexpr,
    HAS_CU_SEQLENS_K_NEW: tl.constexpr,
):
    pid = tl.program_id(0)
    b_start = pid * BLOCK_SIZE_B
    b_offs = b_start + tl.arange(0, BLOCK_SIZE_B)
    mask = b_offs < batch

    if HAS_SEQUSED_Q:
        q_len = tl.load(seqused_q_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_Q:
        cur = tl.load(cu_seqlens_q_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_q_ptr + b_offs + 1, mask=mask, other=0)
        q_len = nxt - cur
    else:
        q_len = tl.full(
            [BLOCK_SIZE_B], max_seqlen_q, dtype=tl.int32
        )  # max_seqlen_q is a constexpr
    q_len = q_len * qhead_per_khead
    m_blocks = (q_len + BLOCK_M - 1) // BLOCK_M

    if HAS_SEQUSED_K:
        k_len = tl.load(seqused_k_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_K:
        cur = tl.load(cu_seqlens_k_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_k_ptr + b_offs + 1, mask=mask, other=0)
        k_len = nxt - cur
    else:
        k_len = tl.load(seqlen_k_ptr + b_offs, mask=mask, other=0)
    left = tl.load(leftpad_k_ptr + b_offs, mask=mask, other=0) if HAS_LEFT_PAD else 0

    if HAS_K_NEW:
        if HAS_CU_SEQLENS_K_NEW:
            cur_new = tl.load(cu_seqlens_k_new_ptr + b_offs, mask=mask, other=0)
            nxt_new = tl.load(cu_seqlens_k_new_ptr + b_offs + 1, mask=mask, other=0)
            k_len += nxt_new - cur_new
        else:
            k_len += max_seqlen_k_new
    k_len = k_len - left
    n_blocks = (k_len + BLOCK_N - 1) // BLOCK_N

    tl.store(num_m_blocks_ptr + b_offs, m_blocks, mask=mask)
    tl.store(num_n_blocks_ptr + b_offs, n_blocks, mask=mask)
    total = m_blocks * n_blocks
    tl.atomic_add(total_blocks_ptr, tl.sum(total, axis=0))

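# Pass 2 (Triton): spread roughly total_blocks * 1.1 * num_head blocks over the
# SMs to get a per-SM budget, then set each batch element's dynamic split count
# to ceil(num_n_blocks / blocks_per_sm), clamped to [1, num_splits_static].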

@triton.jit
def _prepare_pass2_kernel(
    num_n_blocks_per_seq_ptr,
    num_splits_dynamic_ptr,
    total_blocks,
    num_batch,
    num_head,
    num_sm,
    num_splits_static,
    BLOCK_SIZE_B: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    b_start = pid * BLOCK_SIZE_B
    b_offsets = b_start + tl.arange(0, BLOCK_SIZE_B)
    b_mask = b_offsets < num_batch

    blocks_per_sm_float = tl.ceil(total_blocks * 1.1 * num_head / num_sm)
    blocks_per_sm = blocks_per_sm_float.to(tl.int32)

    blocks_per_sm = tl.maximum(1, blocks_per_sm)

    num_n_blocks = tl.load(num_n_blocks_per_seq_ptr + b_offsets, mask=b_mask, other=0)
    num_splits_dynamic = (num_n_blocks + blocks_per_sm - 1) // blocks_per_sm

    num_splits_dynamic = tl.minimum(num_splits_dynamic, num_splits_static)
    num_splits_dynamic = tl.maximum(1, num_splits_dynamic)

    tl.store(num_splits_dynamic_ptr + b_offsets, num_splits_dynamic, mask=b_mask)


def get_pack_gqa(
    arch: int,
    has_page_table: bool,
    pagedkv_tma: bool,
    num_splits: int,
    num_heads: int,
    num_heads_k: int,
    # SM90-specific params for heuristic
    varlen_q: bool,
    seqlen_q: int,
    d_rounded: int,
    dv_rounded: int,
    is_causal: bool,
    is_local: bool,
    element_size: int,
    softcap: bool,
) -> bool:
    if arch < 90 or (has_page_table and not pagedkv_tma) or num_splits > 1:
        return True
    if num_heads == num_heads_k:
        return False
    kBlockM, _ = tile_size_fwd_sm90(
        headdim=d_rounded,
        headdim_v=dv_rounded,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        v_colmajor=False,
        paged_kv_non_TMA=(has_page_table and not pagedkv_tma),
        softcap=softcap,
        use_one_mma_wg=False,
    )
    qhead_per_khead = num_heads // num_heads_k
    return should_pack_gqa(varlen_q, seqlen_q, qhead_per_khead, kBlockM)

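# get_scheduler_metadata: end-to-end entry point. It validates dtype and head
# counts, normalizes the causal/local window, picks tile sizes and a static
# split count, runs the two Triton prepare passes to derive per-batch dynamic
# splits, and returns an int32 tensor laid out as an optional semaphore slot
# followed by per-batch split counts (or an empty tensor when neither is
# needed).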

def get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> torch.Tensor:
    device = seqused_k.device
    dtype = torch.int32

    # Check parameters.
    supported_dtypes = (torch.half, torch.bfloat16)
    assert (
        qkv_dtype in supported_dtypes
    ), "FlashAttention only supports fp16 and bf16 data types"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    # Normalize is_causal / window_size into effective causal and local flags.
    effective_is_causal = is_causal
    effective_window_left = window_size_left if window_size_left >= 0 else -1
    effective_window_right = window_size_right

    if effective_window_left >= max_seqlen_k - 1:
        effective_window_left = -1
    if effective_window_right >= max_seqlen_q - 1:
        effective_window_right = -1

    if (
        max_seqlen_q == 1
        and effective_window_left == -1
        and effective_window_right == -1
    ):
        if (headdim <= 64 or headdim > 128) or page_size is None:
            effective_is_causal = False

    if effective_is_causal:
        effective_window_right = 0

    final_is_causal = effective_window_left < 0 and effective_window_right == 0
    final_is_local = (
        effective_window_left >= 0 or effective_window_right >= 0
    ) and not final_is_causal

    major, minor = get_device_capability()
    arch = major * 10 + minor
    num_sm = get_device_info().sm_count - sm_margin

    softcap = 1.0 if has_softcap else 0.0

    element_size = qkv_dtype.itemsize

    has_page_table = page_size is not None

    d_rounded = round_up_headdim(headdim)
    dv_rounded = round_up_headdimv(headdim_v)

    pagedkv_tma = get_pagedkv_tma(
        arch=arch,
        page_size=page_size if page_size is not None else 1,
        has_page_table=has_page_table,
        leftpad_k=leftpad_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        num_heads=num_heads,
        num_heads_k=num_heads_k,
        d_rounded=d_rounded,
        dv_rounded=dv_rounded,
        is_causal=final_is_causal,
        is_local=final_is_local,
        element_size=element_size,
        softcap=has_softcap,
    )

    blockM, blockN = get_optimal_block_mn(
        device=device,
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=final_is_causal,
        is_local=final_is_local,
        has_softcap=has_softcap,
        element_size=element_size,
        paged_kv=has_page_table,
        pagedkv_tma=pagedkv_tma,
    )

    # GQA: decide whether query heads are packed along the M dimension.
    varlen_q_flag = cu_seqlens_q is not None or seqused_q is not None
    pack_gqa = (
        pack_gqa
        if pack_gqa is not None
        else get_pack_gqa(
            arch=arch,
            has_page_table=has_page_table,
            pagedkv_tma=pagedkv_tma,
            num_splits=num_splits,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            varlen_q=varlen_q_flag,
            seqlen_q=max_seqlen_q,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            is_causal=final_is_causal,
            is_local=final_is_local,
            element_size=element_size,
            softcap=has_softcap,
        )
    )
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    seqlen_q = (
        seqused_q
        if seqused_q is not None
        else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
    )
    seqlen_k = seqused_k
    seqlen_knew = (
        torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
        if max_seqlen_k_new > 0
        else None
    )

    BLOCK_SIZE_B = 128
    grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

    # Dynamic split depends only on batch_size, regardless of num_splits_static.
    use_dynamic_split = batch_size <= 992

    if num_splits <= 0:
        element_size = qkv_dtype.itemsize
        is_fp16 = qkv_dtype == torch.float16
        is_bf16 = qkv_dtype == torch.bfloat16

        if not (is_fp16 or is_bf16):
            raise ValueError(
                f"Unsupported dtype: {qkv_dtype}. "
                "FlashAttention only supports torch.float16 and torch.bfloat16"
            )

        eff_num_splits = get_num_splits(
            batch_size=batch_size,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            headdim=headdim,
            headdim_v=headdim_v,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_k_new=max_seqlen_k_new,
            arch=arch,
            num_sm=num_sm,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=softcap,
            is_varlen=True,
            has_page_table=has_page_table,
            pack_gqa=pack_gqa,
            window_size_left=effective_window_left,
            window_size_right=effective_window_right,
            element_size=element_size,
            use_dynamic_split=use_dynamic_split,
        )
    else:
        eff_num_splits = num_splits

    eff_num_splits = min(eff_num_splits, 256, num_sm)

    # Always enable PackGQA when splitting.
    pack_gqa = True if eff_num_splits > 1 else pack_gqa

    # Recompute qhead_per_khead / num_head_k for the kernels.
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    is_varlen = True
    if arch >= 90:
        uomw = use_one_mma_wg(
            arch=arch,
            headdim=headdim,
            seqlen_q=max_seqlen_q,
            pack_gqa=pack_gqa,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
        blockM, blockN = tile_size_fwd_sm90(
            headdim=round_up_headdim(headdim),
            headdim_v=round_up_headdimv(headdim_v),
            is_causal=final_is_causal,
            is_local=final_is_local,
            element_size=element_size,
            v_colmajor=False,
            paged_kv_non_TMA=(has_page_table and not pagedkv_tma),
            softcap=has_softcap,
            use_one_mma_wg=uomw,
        )
    else:
        blockM, blockN = get_optimal_block_mn(
            device=device,
            headdim=headdim,
            headdim_v=headdim_v,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=has_softcap,
            element_size=element_size,
            paged_kv=has_page_table,
            pagedkv_tma=pagedkv_tma,
            varlen_and_split=is_varlen and (eff_num_splits > 1),
            append_kv=(max_seqlen_k_new > 0),
        )

    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)

    _prepare_pass1_kernel[grid](
        num_m_blocks,
        num_n_blocks,
        total_blocks,
        seqlen_k,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        leftpad_k,
        batch_size,
        qhead_per_khead,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        BLOCK_M=blockM,
        BLOCK_N=blockN,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
        HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
        HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
        HAS_SEQUSED_Q=seqused_q is not None,
        HAS_SEQUSED_K=True,
        HAS_LEFT_PAD=leftpad_k is not None,
        HAS_K_NEW=seqlen_knew is not None,
        HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
    )

    total_blocks_val = total_blocks.item()

    if use_dynamic_split:
        _prepare_pass2_kernel[grid](
            num_n_blocks,
            num_splits_dynamic,
            total_blocks=total_blocks_val,
            num_batch=batch_size,
            num_head=num_head_k,
            num_sm=num_sm,
            num_splits_static=eff_num_splits,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
        )
    else:
        num_splits_dynamic.fill_(eff_num_splits)

    final_num_splits = eff_num_splits

    # SM90 always uses the tile-scheduler semaphore; earlier archs only need it
    # when the work is actually split.
    scheduler_needs_semaphore = arch >= 90 or final_num_splits > 1

    alloc_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size

    if alloc_size > 0:
        scheduler_metadata = torch.empty(alloc_size, dtype=torch.int32, device=device)
        offset = 0
        if scheduler_needs_semaphore:
            # First slot: tile-scheduler semaphore, initialized to zero.
            scheduler_metadata[offset] = 0
            offset += 1

        if use_dynamic_split:
            # Remaining slots: per-batch dynamic split counts.
            scheduler_metadata[offset:] = num_splits_dynamic
        return scheduler_metadata
    else:
        return torch.empty((0,), dtype=torch.int32, device=device)
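

# Minimal usage sketch (illustrative only): the shapes below are assumptions for
# a single-token decode step, not values taken from any particular model, and a
# CUDA device is required because the Triton prepare kernels run on
# seqused_k.device.
if __name__ == "__main__" and torch.cuda.is_available():
    batch, heads_q, heads_kv, hdim = 4, 32, 8, 128
    seqused_k = torch.tensor(
        [1024, 2048, 512, 4096], dtype=torch.int32, device="cuda"
    )
    metadata = get_scheduler_metadata(
        batch_size=batch,
        max_seqlen_q=1,
        max_seqlen_k=4096,
        num_heads=heads_q,
        num_heads_k=heads_kv,
        headdim=hdim,
        headdim_v=hdim,
        qkv_dtype=torch.float16,
        seqused_k=seqused_k,
        is_causal=True,
        page_size=256,
    )
    # Typically: one semaphore slot followed by `batch` dynamic split counts.
    print(metadata.shape, metadata)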