Coverage for src/flag_gems/ops/flash_api.py: 91%

377 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.ops.flash_kernel import ( 

10 block_m_splitkv_heuristic, 

11 block_n_splitkv_heuristic, 

12 flash_fwd_kernel, 

13 flash_fwd_splitkv_combine_kernel, 

14 flash_fwd_splitkv_kernel, 

15 flash_varlen_fwd_kernel, 

16) 

17from flag_gems.runtime import torch_device_fn 

18from flag_gems.utils.random_utils import philox_backend_seed_offset 

19 

# Module logger; kernel selection and launch configs are reported at DEBUG.
logger = logging.getLogger(name=__name__)

# Developer switch read by mha_fwd: when True it allocates a float32 softmax
# buffer, forces return_softmax, and prints compiled-kernel metadata.
_debug = False

22 

23 

def CHECK_DEVICE(x):
    """Validate that ``x`` lives on the device type FlagGems is configured for.

    Args:
        x: A tensor (anything exposing ``.device.type``).

    Raises:
        AssertionError: if ``x.device.type`` differs from ``flag_gems.device``.
            Raised explicitly instead of via an ``assert`` statement so the
            check is not stripped when Python runs with ``-O``; the exception
            type is unchanged for existing callers.
    """
    if x.device.type != flag_gems.device:
        raise AssertionError(
            f"Expected tensor on device type {flag_gems.device!r}, "
            f"got {x.device.type!r}"
        )

26 

27 

class fwd_params:
    """Flat parameter record handed positionally to the flash-attention kernels.

    ``__slots__`` fixes both the set of attributes and the ORDER in which
    :meth:`args` emits them. The forward paths splat ``args()`` straight into
    a kernel launch, so slot order must stay aligned with the kernels'
    positional parameters (defined elsewhere -- keep the two in sync).

    Every constructor parameter shares its name with a slot and is stored
    unchanged; attributes may be overwritten after construction (e.g. to
    redirect an output pointer) before ``args()`` is taken.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        is_paged,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # The parameter list mirrors __slots__ name-for-name, so each argument
        # is copied onto the instance by looking it up in this frame's locals.
        # locals() is captured before any other local is bound, so it holds
        # exactly ``self`` plus the constructor arguments.
        frame_args = locals()
        for slot_name in self.__slots__:
            setattr(self, slot_name, frame_args[slot_name])

    def args(self):
        """Return every slot value as a tuple, in ``__slots__`` order."""
        return tuple([getattr(self, slot_name) for slot_name in self.__slots__])

231 

232 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length (packed) FlashAttention forward pass.

    Args:
        q: Packed queries, shape ``(total_q, num_heads, head_size)``, fp16/bf16,
            contiguous last dimension.
        k, v: Keys/values with the same dtype and shape as each other. When
            ``page_table`` is given they are paged:
            ``(num_pages, block_size, num_heads_k, head_size)``; otherwise
            dim 1 is read as ``num_heads_k`` (presumably
            ``(total_k, num_heads_k, head_size)`` -- TODO confirm).
        out: Optional preallocated output shaped like ``q``.
        cu_seqlens_q, cu_seqlens_k: Contiguous int32 tensors of
            ``batch_size + 1`` sequence offsets.
        seqused_k: Optional contiguous ``(batch_size,)`` tensor; when given it
            is used instead of ``cu_seqlens_k`` (``is_cu_seqlens_k`` is off).
        leftpad_k: Must be ``None`` (not supported).
        page_table: Optional block table enabling the paged-KV layout.
        alibi_slopes: Optional float32 ``(num_heads,)`` or
            ``(batch_size, num_heads)`` tensor.
        max_seqlen_q, max_seqlen_k: Longest query / key sequence in the batch.
        p_dropout: Dropout probability (0 disables dropout).
        softmax_scale: Scale applied to QK^T before softmax.
        zero_tensors: If true, ``out`` is zeroed and ``lse`` filled with -inf
            before the launch.
        is_causal, window_size_left, window_size_right: Masking controls;
            negative window sizes disable that side of the window.
        softcap: Logit soft-capping value (> 0 enables it; incompatible with
            dropout).
        return_softmax: Allocate and return the probability buffer ``p``
            (only allowed together with dropout).
        gen: Accepted for interface compatibility; not referenced here
            (dropout seeds come from ``philox_backend_seed_offset``).

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``. ``out`` is the
        caller's buffer when one was passed. ``q`` is the tensor handed to the
        kernel, which is the regrouped query in the single-query swapped case.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # Empty placeholder keeps page_table.stride(0) and the kernel
        # argument list valid for the non-paged layout.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size: number of sequences (cu_seqlens_q holds batch_size + 1 offsets)
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Single-query batches without ALiBi run with causal masking disabled.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # A window at least as long as the longest key sequence masks nothing, so
    # sliding-window attention is disabled (-1) on that side.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions: the q_groups query heads sharing one KV head are
    # presented to the kernel as q_groups query rows of that KV head.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride depends on the output buffer and is set once it exists.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # With softcap the kernel receives softmax_scale / softcap as its "softcap"
    # value and softcap itself as the softmax scale.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes into a scratch buffer laid out like the
                # swapped q; results are copied back into out_ afterwards.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # Pick one of the "mha_block_*" heuristic configs from the expected
        # number of rows per CTA: the smaller of rows-per-SM and mean
        # rows-per-sequence. Large values suggest the prefill phase, small
        # ones decode. This is a rough heuristic and may not be accurate for
        # all scenarios.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        # max(..., 1) guards an empty batch (cu_seqlens_q with one entry).
        avg_rows_per_batch = total_q / max(batch_size, 1)
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the query-group swap on the output and reshape lse.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p

577 

578 

def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Dense (fixed-length) FlashAttention forward pass.

    Args:
        q: ``(batch_size, seqlen_q, num_heads, head_size)``, fp16/bf16,
            contiguous last dimension.
        k, v: ``(batch_size, seqlen_k, num_heads_k, head_size)`` with the same
            dtype as ``q``; ``num_heads_k`` must divide ``num_heads``.
        out: Optional preallocated output with ``q``'s shape and dtype.
        alibi_slopes: Optional float32 ``(num_heads,)`` or
            ``(batch_size, num_heads)`` tensor.
        p_dropout: Dropout probability (0 disables dropout).
        softmax_scale: Scale applied to QK^T before softmax.
        is_causal, window_size_left, window_size_right: Masking controls;
            negative window sizes disable that side of the window.
        softcap: Logit soft-capping value (> 0 enables it).
        return_softmax: Allocate and return ``p`` (requires dropout).
        disable_splitkv: Force the single-pass ``flash_fwd_kernel``.

    ``head_size`` must be a multiple of 32 and at most 256.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    # A window at least as long as the key sequence masks nothing.
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # Causal is re-derived from the final window sizes.
    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    # Single-query case: treat the q_groups query heads sharing a KV head as
    # q_groups query rows of that KV head.
    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert head_size == head_size_rounded, "head_size must be rounded to 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        """Return the number of KV splits for the split-KV kernel.

        Splitting is only used when the launch fits in a single wave whose
        SM utilisation (``eff``) is at most 0.8; otherwise 1 is returned.
        """
        if num_tasks <= 0:
            # Nothing to schedule; also avoids dividing by zero waves below.
            return 1
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                # View the caller's buffer in the swapped layout so the kernel
                # writes straight into it.
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        # With softcap the kernel receives softmax_scale / softcap as its
        # "softcap" value and softcap itself as the softmax scale.
        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            """Launch the forward kernel(s) and return the launched kernel.

            The return value is only inspected by the ``_debug`` printout.
            """
            # Query through the active vendor backend (same as the varlen
            # path) instead of a hard-coded "cuda" string, so non-CUDA
            # backends work; inside torch_device_fn.device(q_device) this
            # reports q's device.
            num_sms = torch_device_fn.get_device_properties(
                flag_gems.device
            ).multi_processor_count

            # Try splitkv
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(seqlen_k, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    # Each split writes fp32 partial outputs and LSEs, which
                    # the combine kernel then reduces into out / lse.
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): this passes the head_size-dim stride
                        # (-1), not the head stride (-2); confirm against
                        # flash_fwd_splitkv_combine_kernel.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            False,  # is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)

    if seqlenq_ngroups_swapped:
        # Undo the query-group swap for the returned tensors.
        out = out.transpose(1, 2).reshape(
            (batch_size, 1, num_heads_k * seqlen_q, head_size)
        )
        q = q.transpose(1, 2).reshape(
            (batch_size, 1, num_heads_k * seqlen_q, head_size)
        )
        lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p