Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/get_scheduler_metadata.py: 0%

277 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils.device_info import get_device_capability, get_device_info 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

def get_dtype_bytes(dtype):
    """Return the size in bytes of one element of the given torch dtype."""
    # finfo covers floating dtypes, iinfo covers integer dtypes.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits // 8

19 

20 

def tile_size_fwd_sm8x(
    sm86_or_89: bool,
    headdim: int,
    headdim_v: int,
    is_causal: bool,
    is_local: bool,
    element_size: int = 2,
    paged_kv: bool = False,
    varlen_and_split: bool = False,
    softcap: bool = False,
    append_kv: bool = False,
):
    """Select forward-pass tile sizes for SM8x-class GPUs.

    Returns a tuple (kBlockM, kBlockN, kNWarps, kStages, Q_in_regs) keyed on
    the (rounded) head dimension, element size, and kernel variant flags.
    """
    # Anything other than 2-byte elements (fp16/bf16) uses one fixed config.
    if element_size != 2:
        return 128, 64, 8, 2, False

    if headdim <= 64:
        n = 80 if varlen_and_split else (96 if is_local else 112)
        return 128, n, 4, 1, False

    if headdim <= 96:
        n = 48 if (varlen_and_split or is_local) else 64
        return 128, n, 4, 1, False

    if headdim <= 128:
        # SM86/SM89 and varlen-split kernels run wider (8 warps, Q in regs).
        eight_warps = sm86_or_89 or varlen_and_split
        if not eight_warps:
            n = 48 if is_local else 64
        elif varlen_and_split:
            n = 96 if is_local else 112
        else:
            n = 96 if is_local else 128
        return 128, n, (8 if eight_warps else 4), 1, eight_warps

    if headdim <= 192:
        # These variants shrink kBlockN to 64 and give up Q-in-registers.
        narrow_n = append_kv or is_local or varlen_and_split or paged_kv
        return (
            128,
            64 if narrow_n else 96,
            8,
            1 if sm86_or_89 else 2,
            not narrow_n,
        )

    # headdim > 192
    if append_kv:
        n = 32 if sm86_or_89 else 48
    elif varlen_and_split or is_local:
        n = 48 if sm86_or_89 else 64
    else:
        n = 64 if sm86_or_89 else 96
    return 128, n, 8, 1, sm86_or_89 and not append_kv

98 

99 

def get_optimal_block_mn(
    device,
    headdim,
    headdim_v,
    is_causal,
    is_local,
    has_softcap,
    element_size=2,
    paged_kv=False,
    varlen_and_split=False,
    append_kv=False,
):
    """Return the (kBlockM, kBlockN) tile sizes for the current device.

    Queries the device's compute capability and delegates to the SM8x tile
    table; only the M/N block sizes are exposed to callers.
    """
    major, minor = get_device_capability()
    # SM86 (GA10x) and SM89 (Ada) share a dedicated tile-size table.
    is_sm86_or_89 = (major * 10 + minor) in (86, 89)

    block_m, block_n, _warps, _stages, _q_in_regs = tile_size_fwd_sm8x(
        sm86_or_89=is_sm86_or_89,
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=is_causal,
        is_local=is_local,
        element_size=element_size,
        paged_kv=paged_kv,
        varlen_and_split=varlen_and_split,
        softcap=has_softcap,
        append_kv=append_kv,
    )
    return block_m, block_n

130 

131 

def round_up_headdim(headdim: int) -> int:
    """Round a query/key head dimension up to the nearest supported bucket.

    Buckets are 64/96/128/192/256; anything larger is clamped to 256.
    """
    for bucket in (64, 96, 128, 192, 256):
        if headdim <= bucket:
            return bucket
    return 256

144 

145 

def round_up_headdimv(headdim_v: int) -> int:
    """Round a value head dimension up to the nearest supported bucket.

    Buckets are 64/96/128/192/256; unlike round_up_headdim, anything above
    256 maps to 512.
    """
    for bucket in (64, 96, 128, 192, 256):
        if headdim_v <= bucket:
            return bucket
    return 512

158 

159 

def use_one_mma_wg(
    arch: int,
    headdim: int,
    seqlen_q: int,
    pack_gqa: bool,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    """Whether a single MMA warp-group suffices.

    Only considered on Hopper+ (arch >= 90) with headdim == 128 and a short
    effective query length (<= 64 rows).
    """
    if arch < 90 or headdim != 128:
        return False

    # pack_gqa folds the query heads of one KV head into the query length.
    heads_ratio = num_heads // num_heads_k if pack_gqa else 1
    return seqlen_q * heads_ratio <= 64

175 

176 

def get_num_splits(
    batch_size: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    d_rounded: int,
    dv_rounded: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    max_seqlen_k_new: int,
    arch: int,
    num_sm: int,
    is_causal: bool,
    is_local: bool,
    has_softcap: float,
    is_varlen: bool,
    has_page_table: bool,
    element_size: int = 2,  # fp16/bf16 = 2, fp8 = 1
    max_splits: int = 128,
    use_dynamic_split: bool = False,
) -> int:
    """Compute the static Split-K partition count for the forward kernel.

    Derives the tile sizes for the target architecture, converts the padded
    sequence lengths into M/N block counts, and feeds the totals into the
    vLLM-style split heuristic.
    """
    pagedkv_tma = False  # TMA paged-KV path is not implemented here
    appending_kv = max_seqlen_k_new > 0

    if arch >= 90:
        # TODO: tile_size_fwd_sm90
        block_m, block_n = get_optimal_block_mn(
            device=0,
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            has_softcap=has_softcap,
            element_size=element_size,
            paged_kv=has_page_table and not pagedkv_tma,
            varlen_and_split=is_varlen,
            append_kv=appending_kv,
        )
    else:
        block_m, block_n, _, _, _ = tile_size_fwd_sm8x(
            sm86_or_89=arch in (86, 89),
            headdim=d_rounded,
            headdim_v=dv_rounded,
            is_causal=is_causal,
            is_local=is_local,
            element_size=element_size,
            paged_kv=has_page_table,
            varlen_and_split=is_varlen,
            softcap=(has_softcap > 0.0),
            append_kv=appending_kv,
        )

    # pack-GQA folds all query heads of one KV head into the M dimension.
    seqlen_q_packgqa = max_seqlen_q * (num_heads // num_heads_k)

    # A local window only ever touches about (kBlockM + seqlen_q) keys.
    if is_local:
        seqlen_k_loaded = max(0, min(max_seqlen_k, block_m + max_seqlen_q))
    else:
        seqlen_k_loaded = max_seqlen_k

    # ceil-divide into N/M tile counts.
    num_n_blocks = -(-seqlen_k_loaded // block_n)
    num_m_blocks = -(-seqlen_q_packgqa // block_m)

    # Bytes one KV head occupies across the whole key sequence.
    size_one_kv_head = max_seqlen_k * (headdim + headdim_v) * element_size

    # Dynamic splitting sizes per sequence, so count a single batch entry.
    effective_batch = 1 if use_dynamic_split else batch_size
    total_mblocks = effective_batch * num_heads_k * num_m_blocks

    return _vllm_num_splits_heuristic(
        total_mblocks=total_mblocks,
        num_sm=num_sm,
        num_n_blocks=num_n_blocks,
        num_m_blocks=num_m_blocks,
        size_one_kv_head=size_one_kv_head,
        is_causal_or_local=is_causal or is_local,
        max_splits=max_splits,
    )

255 

256 

257def _vllm_num_splits_heuristic( 

258 total_mblocks: int, 

259 num_sm: int, 

260 num_n_blocks: int, 

261 num_m_blocks: int, 

262 size_one_kv_head: int, 

263 is_causal_or_local: bool, 

264 max_splits: int, 

265) -> int: 

266 if total_mblocks >= 0.8 * num_sm: 

267 size_l2 = 50 * 1024 * 1024 

268 if ( 

269 size_one_kv_head > size_l2 

270 and num_m_blocks >= num_sm * 2 

271 and not is_causal_or_local 

272 ): 

273 return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits) 

274 else: 

275 return 1 

276 

277 if num_n_blocks <= 4: 

278 return 1 

279 

280 max_splits = min(max_splits, num_sm, num_n_blocks) 

281 

282 max_efficiency = 0.0 

283 efficiencies = [] 

284 

285 for num_splits in range(1, max_splits + 1): 

286 n_waves = float(total_mblocks * num_splits) / num_sm 

287 eff = n_waves / math.ceil(n_waves) 

288 if eff > max_efficiency: 

289 max_efficiency = eff 

290 efficiencies.append(eff) 

291 

292 for num_splits in range(1, max_splits + 1): 

293 if efficiencies[num_splits - 1] >= 0.85 * max_efficiency: 

294 return num_splits 

295 

296 return 1 

297 

298 

@triton.jit
def _prepare_pass1_kernel(
    num_m_blocks_ptr,
    num_n_blocks_ptr,
    total_blocks_ptr,
    seqlen_k_ptr,
    cu_seqlens_q_ptr,
    cu_seqlens_k_ptr,
    cu_seqlens_k_new_ptr,
    seqused_q_ptr,
    seqused_k_ptr,
    leftpad_k_ptr,
    batch,
    qhead_per_khead,
    max_seqlen_q: tl.constexpr,
    max_seqlen_k_new: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
    # HAS_XXX flags select static (compile-time) branches in the kernel
    HAS_CU_SEQLENS_Q: tl.constexpr,
    HAS_CU_SEQLENS_K: tl.constexpr,
    HAS_SEQUSED_Q: tl.constexpr,
    HAS_SEQUSED_K: tl.constexpr,
    HAS_LEFT_PAD: tl.constexpr,
    HAS_K_NEW: tl.constexpr,
    HAS_CU_SEQLENS_K_NEW: tl.constexpr,
):
    """Pass 1: per-sequence M/N tile counts and the grand total of tiles.

    For each batch entry the kernel resolves the query/key lengths (from
    seqused_*, cu_seqlens_* deltas, or the padded maxima), converts them into
    BLOCK_M/BLOCK_N tile counts, stores the per-sequence counts, and
    atomically accumulates m_blocks * n_blocks into total_blocks_ptr.
    """
    pid = tl.program_id(0)
    b_start = pid * BLOCK_SIZE_B
    b_offs = b_start + tl.arange(0, BLOCK_SIZE_B)
    mask = b_offs < batch

    # --- query length per sequence ---
    if HAS_SEQUSED_Q:
        q_len = tl.load(seqused_q_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_Q:
        cur = tl.load(cu_seqlens_q_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_q_ptr + b_offs + 1, mask=mask, other=0)
        q_len = nxt - cur
    else:
        # Dense case: every sequence is padded to max_seqlen_q (constexpr).
        q_len = tl.full([BLOCK_SIZE_B], max_seqlen_q, dtype=tl.int32)
    # pack-GQA folds the query heads of one KV head into the query length.
    q_len = q_len * qhead_per_khead
    m_blocks = (q_len + BLOCK_M - 1) // BLOCK_M

    # --- key length per sequence ---
    if HAS_SEQUSED_K:
        k_len = tl.load(seqused_k_ptr + b_offs, mask=mask, other=0)
    elif HAS_CU_SEQLENS_K:
        cur = tl.load(cu_seqlens_k_ptr + b_offs, mask=mask, other=0)
        nxt = tl.load(cu_seqlens_k_ptr + b_offs + 1, mask=mask, other=0)
        k_len = nxt - cur
    else:
        k_len = tl.load(seqlen_k_ptr + b_offs, mask=mask, other=0)
    left = tl.load(leftpad_k_ptr + b_offs, mask=mask, other=0) if HAS_LEFT_PAD else 0

    # Appended keys extend the key length before tiling.
    if HAS_K_NEW:
        if HAS_CU_SEQLENS_K_NEW:
            cur_new = tl.load(cu_seqlens_k_new_ptr + b_offs, mask=mask, other=0)
            nxt_new = tl.load(cu_seqlens_k_new_ptr + b_offs + 1, mask=mask, other=0)
            k_len += nxt_new - cur_new
        else:
            k_len += max_seqlen_k_new
    k_len = k_len - left
    n_blocks = (k_len + BLOCK_N - 1) // BLOCK_N

    tl.store(num_m_blocks_ptr + b_offs, m_blocks, mask=mask)
    tl.store(num_n_blocks_ptr + b_offs, n_blocks, mask=mask)
    # BUGFIX: zero out-of-batch lanes before the reduction. The stores above
    # are masked, but tl.sum is not: in the dense-q path (and when the
    # constexpr max_seqlen_k_new is added), lanes with b_offs >= batch hold
    # nonzero m_blocks * n_blocks and would inflate the atomic total whenever
    # batch % BLOCK_SIZE_B != 0.
    total = tl.where(mask, m_blocks * n_blocks, 0)
    tl.atomic_add(total_blocks_ptr, tl.sum(total, axis=0))

370 

@triton.jit
def _prepare_pass2_kernel(
    num_n_blocks_per_seq_ptr,
    num_splits_dynamic_ptr,
    total_blocks,
    num_batch,
    num_head,
    num_sm,
    num_splits_static,
    BLOCK_SIZE_B: tl.constexpr,
):
    """
    Triton Kernel: Pass 2
    - Calculates the dynamic number of splits for the Split-K optimization,
    based on the total number of blocks computed in Pass 1.
    - For each sequence, the split count is its N-block count divided by a
    per-SM block budget, clamped to [1, num_splits_static].
    """
    pid = tl.program_id(axis=0)
    b_start = pid * BLOCK_SIZE_B
    b_offsets = b_start + tl.arange(0, BLOCK_SIZE_B)
    b_mask = b_offsets < num_batch

    # Per-SM block budget across all heads; the 1.1 factor adds ~10% slack
    # (presumably to avoid over-splitting at the boundary — confirm against
    # the upstream heuristic).
    blocks_per_sm_float = tl.ceil(total_blocks * 1.1 * num_head / num_sm)
    blocks_per_sm = blocks_per_sm_float.to(tl.int32)

    # Budget must be at least one block per SM.
    blocks_per_sm = tl.maximum(1, blocks_per_sm)

    num_n_blocks = tl.load(num_n_blocks_per_seq_ptr + b_offsets, mask=b_mask, other=0)
    # Ceil-divide each sequence's N blocks by the per-SM budget.
    num_splits_dynamic = (num_n_blocks + blocks_per_sm - 1) // blocks_per_sm

    # Clamp to [1, num_splits_static].
    num_splits_dynamic = tl.minimum(num_splits_dynamic, num_splits_static)
    num_splits_dynamic = tl.maximum(1, num_splits_dynamic)

    tl.store(num_splits_dynamic_ptr + b_offsets, num_splits_dynamic, mask=b_mask)

405 

def get_pack_gqa(
    arch: int,
    has_page_table: bool,
    pagedkv_tma: bool,
    num_splits: int,
    num_heads: int,
    num_heads_k: int,
) -> bool:
    """Decide whether query heads should be packed per KV head (pack-GQA).

    Packing is forced pre-Hopper, with non-TMA paged KV, or when Split-K is
    active; otherwise it is currently disabled.
    """
    # These variants always benefit from packing.
    forced = arch < 90 or (has_page_table and not pagedkv_tma) or num_splits > 1
    if forced:
        return True

    # Plain MHA: nothing to pack.
    if num_heads == num_heads_k:
        return False

    # TODO: implement tile_size_fwd_sm90 and should_pack_gqa (Hopper+ only)
    return False

422 

423 

def get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Build the scheduler-metadata tensor for the FlashAttention forward pass.

    Runs the two Triton prepare kernels: pass 1 computes per-sequence tile
    counts and the total tile count; pass 2 (when dynamic splitting is used)
    derives a per-sequence Split-K factor. The returned int32 tensor holds an
    optional semaphore slot (the total tile count) followed by the
    per-sequence split counts when dynamic splitting is active, or is empty
    when neither is needed.

    Raises:
        AssertionError / ValueError: for unsupported dtypes or head counts.
    """
    device = seqused_k.device
    dtype = torch.int32

    # ---- parameter checks ----
    supported_dtypes = (torch.half, torch.bfloat16)
    assert (
        qkv_dtype in supported_dtypes
    ), "FlashAttention only supports fp16 and bf16 data type"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    # ---- normalize causal / sliding-window flags ----
    effective_is_causal = is_causal
    effective_window_left = window_size_left if window_size_left >= 0 else -1
    effective_window_right = window_size_right

    # A window that already covers the whole sequence is no window at all.
    if effective_window_left >= max_seqlen_k - 1:
        effective_window_left = -1
    if effective_window_right >= max_seqlen_q - 1:
        effective_window_right = -1

    # Single-query decode without windows: drop causality except for the
    # paged-KV mid-headdim case (mirrors the upstream special case).
    if (
        max_seqlen_q == 1
        and effective_window_left == -1
        and effective_window_right == -1
    ):
        if (headdim <= 64 or headdim > 128) or page_size is None:
            effective_is_causal = False

    if effective_is_causal:
        effective_window_right = 0

    final_is_causal = effective_window_left < 0 and effective_window_right == 0
    final_is_local = (
        effective_window_left >= 0 or effective_window_right >= 0
    ) and not final_is_causal

    major, minor = get_device_capability()
    arch = major * 10 + minor
    num_sm = get_device_info().sm_count - sm_margin

    softcap = 1.0 if has_softcap else 0.0
    element_size = get_dtype_bytes(qkv_dtype)
    has_page_table = page_size is not None

    # TODO implement get_pagedkv_tma function (Hopper+ only)
    pagedkv_tma = False

    blockM, blockN = get_optimal_block_mn(
        device=device,
        headdim=headdim,
        headdim_v=headdim_v,
        is_causal=final_is_causal,
        is_local=final_is_local,
        has_softcap=has_softcap,
        element_size=element_size,
    )

    # ---- GQA packing (initial decision; overridden after split sizing) ----
    pack_gqa = (
        pack_gqa
        if pack_gqa is not None
        else get_pack_gqa(
            arch=arch,
            has_page_table=has_page_table,
            pagedkv_tma=pagedkv_tma,
            num_splits=num_splits,  # Note: user-provided num_splits, not eff_num_splits
            num_heads=num_heads,
            num_heads_k=num_heads_k,
        )
    )
    qhead_per_khead = (
        1 if not pack_gqa else (num_heads + num_heads_k - 1) // num_heads_k
    )
    num_head_k = num_heads_k if pack_gqa else num_heads

    # TODO: implement use_one_mma_wg (Hopper+ only)

    seqlen_q = (
        seqused_q
        if seqused_q is not None
        else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
    )
    seqlen_k = seqused_k
    seqlen_knew = (
        torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
        if max_seqlen_k_new > 0
        else None
    )

    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)

    BLOCK_SIZE_B = 128
    grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

    _prepare_pass1_kernel[grid](
        num_m_blocks,
        num_n_blocks,
        total_blocks,
        seqlen_k,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        leftpad_k,
        batch_size,
        qhead_per_khead,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_new=max_seqlen_k_new,
        BLOCK_M=blockM,
        BLOCK_N=blockN,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
        HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
        HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
        HAS_SEQUSED_Q=seqused_q is not None,
        HAS_SEQUSED_K=True,
        HAS_LEFT_PAD=leftpad_k is not None,
        HAS_K_NEW=seqlen_knew is not None,
        HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
    )

    # Host-side sync: the total is needed below to size pass 2.
    total_blocks_val = total_blocks.item()

    # Dynamic per-sequence splitting only when the caller did not pin a
    # split count and the batch fits the metadata layout.
    use_dynamic_split = (num_splits <= 0) and (batch_size <= 992)

    if num_splits <= 0:
        # Redundant with the assert above, but survives `python -O`.
        if qkv_dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"
            )

        d_rounded = round_up_headdim(headdim)
        dv_rounded = round_up_headdimv(headdim_v)

        eff_num_splits = get_num_splits(
            batch_size=batch_size,
            num_heads=num_heads,
            num_heads_k=num_heads_k,
            headdim=headdim,
            headdim_v=headdim_v,
            d_rounded=d_rounded,
            dv_rounded=dv_rounded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_k_new=max_seqlen_k_new,
            arch=arch,
            num_sm=num_sm,
            is_causal=final_is_causal,
            is_local=final_is_local,
            has_softcap=softcap,
            is_varlen=True,
            has_page_table=has_page_table,
            element_size=element_size,
            use_dynamic_split=use_dynamic_split,
        )
    else:
        eff_num_splits = num_splits

    eff_num_splits = min(eff_num_splits, 256, num_sm)

    # NOTE(review): this overrides both the get_pack_gqa decision and any
    # caller-provided pack_gqa based solely on the split count — confirm
    # this is intended rather than leftover debugging.
    pack_gqa = eff_num_splits > 1

    if pack_gqa:
        qhead_per_khead = (num_heads + num_heads_k - 1) // num_heads_k
        num_head_k = num_heads_k
    else:
        qhead_per_khead = 1
        num_head_k = num_heads

    if use_dynamic_split:
        _prepare_pass2_kernel[grid](
            num_n_blocks,
            num_splits_dynamic,
            total_blocks=total_blocks_val,
            num_batch=batch_size,
            num_head=num_head_k,
            num_sm=num_sm,
            num_splits_static=eff_num_splits,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
        )
    else:
        num_splits_dynamic.fill_(eff_num_splits)

    # Semaphore slot is allocated on Hopper+ or whenever Split-K is active.
    # (The original computed two intermediate variants of this flag and then
    # unconditionally overwrote them with this expression — dead code removed.)
    scheduler_needs_semaphore = arch >= 90 or eff_num_splits > 1

    alloc_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size

    if alloc_size == 0:
        return torch.empty((0,), dtype=torch.int32, device=device)

    scheduler_metadata = torch.empty(alloc_size, dtype=torch.int32, device=device)
    offset = 0
    if scheduler_needs_semaphore:
        # Slot 0: total tile count, used as the work-stealing semaphore seed.
        scheduler_metadata[offset] = total_blocks_val
        offset += 1
    if use_dynamic_split:
        scheduler_metadata[offset:] = num_splits_dynamic
    return scheduler_metadata