Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/flash_api.py: 0%

364 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import philox_backend_seed_offset 

11 

12from .flash_kernel import ( 

13 block_m_splitkv_heuristic, 

14 block_n_splitkv_heuristic, 

15 flash_fwd_kernel, 

16 flash_fwd_splitkv_combine_kernel, 

17 flash_fwd_splitkv_kernel, 

18 flash_varlen_fwd_kernel, 

19) 

20 

# Module logger, created as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Debug switch read by mha_fwd: when True it allocates a float32 `p` buffer,
# forces return_softmax, and prints compiled-kernel metadata after dispatch.
_debug = False

23 

24 

def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the device type FlagGems is configured for.

    Compares ``x.device.type`` with ``flag_gems.device`` and raises
    ``AssertionError`` on mismatch (like every ``assert``, this check is
    skipped when Python runs with ``-O``).
    """
    actual_type = x.device.type
    assert actual_type == flag_gems.device

27 

28 

class fwd_params:
    """Flat record of every argument handed to the flash-attention forward kernels.

    The kernels in this module are launched with ``*params.args()``, so the
    order of ``__slots__`` is the positional-argument order the kernels see.
    Each constructor parameter has exactly the same name as a slot.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Parameter names mirror the slot names one-to-one, so copy each
        # argument into its slot by name rather than spelling out every
        # assignment by hand.
        supplied = locals()
        for field in fwd_params.__slots__:
            setattr(self, field, supplied[field])

    def args(self):
        """Return every slot value as a tuple, ordered as in ``__slots__``."""
        return tuple([getattr(self, field) for field in self.__slots__])

229 

230 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Run the variable-length flash-attention forward pass over a paged KV cache.

    Args:
        q: Packed queries of shape ``(total_q, num_heads, head_size)``; fp16 or
            bf16 with a contiguous last dimension.
        k, v: Paged key/value caches of shape
            ``(num_pages, block_size, num_heads_k, head_size)`` with identical
            strides and the same dtype as ``q``.
        out: Optional preallocated output with the same shape and dtype as
            ``q``. When ``None`` a new tensor is allocated.
        cu_seqlens_q, cu_seqlens_k: Contiguous int32 tensors of length
            ``batch_size + 1``; ``batch_size`` is derived from ``cu_seqlens_q``.
        seqused_k: Optional contiguous tensor of length ``batch_size``. When
            given, ``is_cu_seqlens_k`` is passed to the kernel as ``False``.
        leftpad_k: Must be ``None``; left padding is not supported.
        page_table: Required. Its ``stride(0)`` is passed as the page-table
            batch stride and the tensor itself is forwarded to the kernel.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: Maximum query / key lengths in the batch.
        p_dropout: Dropout probability; values ``> 0`` enable dropout.
        softmax_scale: Softmax scaling factor, forwarded as ``scale_softmax``
            (or folded into ``softcap`` when soft-capping is enabled).
        zero_tensors: When true, ``out`` is zeroed and the LSE buffer is filled
            with ``-inf`` before the kernel launch.
        is_causal: Causal-masking flag; forced off when ``max_seqlen_q == 1``
            and no alibi slopes are given.
        window_size_left, window_size_right: Sliding-window sizes. Values
            ``>= max_seqlen_k`` are replaced by ``-1`` (window disabled).
        softcap: Logit soft-capping value; ``> 0`` enables it and is not
            allowed together with dropout.
        return_softmax: When true, allocates the ``p`` buffer; requires dropout.
        gen: Accepted for interface compatibility; not read by this function.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``. ``lse`` has shape
        ``(num_heads, total_q)``, or ``(num_heads_k * max_seqlen_q, batch_size)``
        when the single-query group-swap optimization was applied. In the
        swapped case ``q`` is reshaped back to ``(batch_size, num_heads,
        head_size)``, the same layout ``mha_fwd`` restores.
    """
    CHECK_DEVICE(q)
    CHECK_DEVICE(k)
    CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #     paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size: number of sequences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Validate head configuration on the caller's shapes, before the group
    # swap below rewrites num_heads (this mirrors the ordering in mha_fwd and
    # keeps the divisibility check meaningful on the swapped path).
    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        # (batch, num_heads, d) -> (batch * q_groups, num_heads_k, d)
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride is set once `out` has been allocated below.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_multiple(x, m):
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074  # log2(e)
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        # out_ keeps the caller's buffer. In the swapped case the kernel writes
        # into a temporary shaped like the swapped q and the result is copied
        # back into out_ afterwards.
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (  # noqa: E731
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # We assess which phase the requests are likely to be in and set the
        # config accordingly.
        # prefill_config: BLOCK_M=128, BLOCK_N=32, num_warps=4, num_stages=3
        # decode_config: BLOCK_M=32, BLOCK_N=32, num_warps=4, num_stages=3
        # An empty batch has no tokens; avoid dividing by zero.
        avg_seqlen_q = total_q / batch_size if batch_size > 0 else 0.0
        if avg_seqlen_q >= 256:
            varlen_fwd_config_str = "mha_varlen_prefill"
        else:
            varlen_fwd_config_str = "mha_varlen_decode"
        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args),
        }

        logger.debug("Average query sequence length: %d", avg_seqlen_q)
        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # The kernel wrote `out` shaped like the swapped q,
            # (batch_size * q_groups, num_heads_k, head_size); undo the swap.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # Restore q to the caller's (batch_size, num_heads, head_size)
            # layout, as mha_fwd does for its swapped path.
            q = (
                q.reshape(batch_size, max_seqlen_q, num_heads_k, head_size)
                .transpose(1, 2)
                .reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            )
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p

559 

560 

def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Run the fixed-length (batched) flash-attention forward pass.

    Args:
        q: Queries of shape ``(batch_size, seqlen_q, num_heads, head_size)``;
            fp16 or bf16 with a contiguous last dimension. ``head_size`` must
            be a multiple of 32 and at most 256.
        k, v: Keys / values with the same dtype as ``q`` and a contiguous last
            dimension. ``k`` is unpacked as ``(_, seqlen_k, num_heads_k, _)``;
            ``v`` is assumed to match (not asserted here).
        out: Optional output of shape ``(batch_size, seqlen_q, num_heads,
            head_size)`` with ``q``'s dtype, on the configured device. When
            ``None`` a new tensor is allocated.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        p_dropout: Dropout probability; values ``> 0`` enable dropout.
        softmax_scale: Softmax scaling factor, forwarded as ``scale_softmax``
            (or folded into ``softcap`` when soft-capping is enabled).
        is_causal: Causal-masking flag; recomputed from the window sizes.
        window_size_left, window_size_right: Sliding-window sizes. Values
            ``>= seqlen_k`` are replaced by ``-1`` (window disabled).
        softcap: Logit soft-capping value; ``> 0`` enables it.
        return_softmax: When true, allocates the ``p`` buffer; requires dropout.
        disable_splitkv: When true, never take the split-KV kernel path.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``. ``lse`` has shape
        ``(batch_size, num_heads, seqlen_q)``. When the single-query
        group-swap optimization is applied, ``out`` and ``q`` are returned as
        ``(batch_size, 1, num_heads, head_size)`` and ``lse`` as
        ``(batch_size, num_heads, 1)``.
    """
    CHECK_DEVICE(q)
    CHECK_DEVICE(k)
    CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    # A window at least as long as the key sequence is treated as disabled.
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    # Causal masking is turned off for single-token queries without alibi.
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    # With one query token and more query heads than KV heads, fold the query
    # groups into the sequence dimension: the kernel then sees
    # seqlen_q = q_groups and num_heads = num_heads_k.
    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    def round_multiple(x, m):
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    # ONLY EVEN_K IS SUPPORTED: head_size must already be a multiple of 32.
    assert head_size == head_size_rounded, "head_size must be rounded to 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        """Choose how many KV splits to use; 1 means "do not split"."""
        # Nothing to schedule (e.g. seqlen_q == 0 or an empty batch) or no SM
        # count reported: splitting cannot help, and the wave arithmetic below
        # would divide by zero.
        if num_tasks <= 0 or num_sms <= 0:
            return 1

        # Split only when the tasks fill a single wave poorly.
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        min_blocks_per_split = 2
        return min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            math.floor(1.0 / eff),
            num_sms,
        )

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                # Reshape the caller's buffer into the swapped
                # (batch, q_groups, num_heads_k, head_size) layout.
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        M_LOG2E = 1.4426950408889634074  # log2(e)
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            """Launch split-KV + combine when profitable, otherwise flash_fwd.

            B, H, Q, K, D are batch size, head count, query length, key length
            and head size. Returns the launched kernel handle (the split kernel
            on the split-KV path).
            """
            num_sms = torch_device_fn.get_device_properties(
                "cuda"
            ).multi_processor_count

            # Try splitkv (not used with dropout or a local window)
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(Q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(K, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (  # noqa: E731
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    # Partial results go to the per-split scratch buffers.
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)  # noqa: E731
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): this passes the last-dim stride
                        # (stride(-1)), not the head stride (stride(-2));
                        # confirm against the combine kernel's indexing.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (  # noqa: E731
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            return kernel(*params.args())

        if _debug:
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)

        if seqlenq_ngroups_swapped:
            # Undo the group swap: (batch, q_groups, num_heads_k, d) back to
            # (batch, 1, num_heads_k * q_groups, d).
            out = out.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            q = q.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p