Coverage for src/flag_gems/ops/flash_api.py: 91%

373 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.ops.flash_kernel import ( 

10 block_m_splitkv_heuristic, 

11 block_n_splitkv_heuristic, 

12 flash_fwd_kernel, 

13 flash_fwd_splitkv_combine_kernel, 

14 flash_fwd_splitkv_kernel, 

15 flash_varlen_fwd_kernel, 

16) 

17from flag_gems.runtime import torch_device_fn 

18from flag_gems.utils.random_utils import philox_backend_seed_offset 

19 

logger = logging.getLogger(__name__)
# Module-wide debug switch: when True, mha_fwd allocates a full attention-
# probability buffer, forces return_softmax, and prints kernel metadata
# (shared memory, num_warps, num_stages) after dispatch.
_debug = False

22 

23 

def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the active flag_gems device type."""
    actual_device_type = x.device.type
    assert actual_device_type == flag_gems.device

26 

27 

class fwd_params:
    """Flat container for every argument handed to the flash-attention
    forward Triton kernels.

    The field order of ``__slots__`` matches the positional parameter order
    the kernels expect; :meth:`args` serializes an instance in exactly that
    order so callers can splat it straight into a kernel launch.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Every parameter name mirrors its slot name one-to-one, so the whole
        # field assignment reduces to a single loop over __slots__.
        provided = locals()
        for name in self.__slots__:
            setattr(self, name, provided[name])

    def args(self):
        """Return all fields as a tuple, in ``__slots__`` order."""
        return tuple(getattr(self, name) for name in self.__slots__)

228 

229 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length (ragged-batch) flash-attention forward with a paged
    KV cache.

    Args:
        q: Query tensor of shape ``(total_q, num_heads, head_size)``.
        k, v: Paged KV cache, shape ``(num_pages, block_size, num_heads_k,
            head_size)``; K and V must share sizes and strides.
        out: Optional pre-allocated output, same shape/dtype as ``q``.
        cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths of
            shape ``(batch_size + 1,)``.
        seqused_k: Optional per-batch number of used K tokens,
            shape ``(batch_size,)``.
        leftpad_k: Must be None (left padding is not supported).
        page_table: Required page table mapping sequences to cache pages.
        alibi_slopes: Optional float tensor ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: Maximum per-sequence lengths.
        p_dropout: Dropout probability; softcap requires it to be zero.
        softmax_scale: Scale applied to QK^T before softmax.
        zero_tensors: When True, zero ``out`` and fill lse with -inf first.
        is_causal, window_size_left, window_size_right: Masking controls.
        softcap: When > 0, enables tanh softcapping (rescales the softmax
            factors accordingly).
        return_softmax: Return the attention probabilities (requires
            dropout > 0).
        gen: Unused RNG generator placeholder (kept for API compatibility).

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Causal masking is a no-op for single-token queries (unless alibi needs
    # the position information).
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        # (batch, hk, groups, d) -> treat the query-group axis as a fake
        # sequence axis so each KV head sees q_groups "tokens".
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # With softcapping the kernel computes tanh(qk * softcap') * scale', so
    # the scale and cap factors are swapped/rescaled here.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # Kernel writes into a swapped-layout scratch buffer; results
                # are copied back into the caller's tensor afterwards.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the KEEP probability, not the drop rate.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): cu_seqlens_k is flagged usable only when seqused_k
            # is absent — presumably seqused_k takes precedence; confirm
            # against the kernel.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        # Consistency fix: use the class's own accessor instead of rebuilding
        # the tuple inline with a generator whose loop variable shadowed `k`.
        args = params.args()

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: if avg_rows_per_sm >= 128, we are likely in prefill phase.
        # This is a rough heuristic and may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the group/sequence swap performed above.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # NOTE(review): the second reshape flat-reinterprets the
            # (hk, b, sq) layout as (hk*sq, b) without a transpose — confirm
            # this matches the layout downstream consumers expect.
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p

570 

571 

def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Dense-batch flash-attention forward.

    Args:
        q: Query tensor of shape ``(batch, seqlen_q, num_heads, head_size)``.
        k, v: Key/value tensors of shape
            ``(batch, seqlen_k, num_heads_k, head_size)``.
        out: Optional pre-allocated output, same shape/dtype as ``q``.
        alibi_slopes: Optional float tensor ``(num_heads,)`` or
            ``(batch, num_heads)``.
        p_dropout: Dropout probability.
        softmax_scale: Scale applied to QK^T before softmax.
        is_causal: Apply causal masking.
        window_size_left, window_size_right: Sliding-window sizes; negative
            disables the respective side.
        softcap: When > 0, enables tanh softcapping.
        return_softmax: Return attention probabilities (requires dropout > 0).
        disable_splitkv: Skip the split-KV dispatch path.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    # Windows covering the whole sequence are equivalent to no window.
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    # Causal masking is a no-op for single-token queries (unless alibi needs
    # the position information).
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # Re-derive the masking mode from the (possibly adjusted) window sizes.
    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        # Treat the query-group axis as a fake sequence axis so each KV head
        # sees q_groups "tokens".
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert head_size == head_size_rounded, "head_size must be rounded to 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        # splits when wave efficiency is low
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the KEEP probability, not the drop rate.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        # With softcapping the kernel computes tanh(qk * softcap') * scale',
        # so the scale and cap factors are swapped/rescaled here.
        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            # Consistency fix: query the active flag_gems device instead of a
            # hard-coded "cuda" string (matches mha_varlan_fwd; this file
            # explicitly supports non-CUDA vendors like iluvatar/mthreads).
            num_sms = torch_device_fn.get_device_properties(
                flag_gems.device
            ).multi_processor_count

            # Try bh parallel
            # if B * H > 0.8 * num_sms:
            #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
            #     # Yield kernel and prefilled args
            #     return kernel, default_args, None, None

            # Try splitkv
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(seqlen_k, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    # Redirect kernel output into the per-split scratch
                    # buffers; the combine kernel reduces them below.
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): head stride is passed as stride(-1)
                        # (element stride), not stride(-2) — confirm against
                        # the combine kernel's expectations.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
            # print(kernel.asm['ttgir'])

        if seqlenq_ngroups_swapped:
            # Undo the group/sequence swap performed above.
            out = out.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            q = q.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p