Coverage for src/flag_gems/runtime/backend/_hygon/ops/flash_api.py: 0%

375 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import philox_backend_seed_offset 

11 

12from .flash_kernel import ( 

13 block_m_splitkv_heuristic, 

14 block_n_splitkv_heuristic, 

15 flash_fwd_kernel, 

16 flash_fwd_splitkv_combine_kernel, 

17 flash_fwd_splitkv_kernel, 

18 flash_varlen_fwd_kernel, 

19) 

20 

logger = logging.getLogger(__name__)
# Module-wide debug switch: when True, mha_fwd allocates a full probability
# buffer, forces return_softmax, and prints kernel metadata after dispatch.
_debug = False

23 

24 

def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the active flag_gems device.

    The original bare ``assert`` produced an AssertionError with no message;
    include both the expected and actual device type so failures are
    diagnosable from the traceback alone.
    """
    assert (
        x.device.type == flag_gems.device
    ), f"expected tensor on '{flag_gems.device}' device, got '{x.device.type}'"

27 

28 

class fwd_params:
    """Plain parameter container for the flash-attention forward kernels.

    Holds every scalar/tensor argument the Triton forward kernels expect, in
    the exact positional order declared by ``__slots__``.  ``args()`` flattens
    the fields back into that positional tuple for kernel launches.

    The ``__init__`` parameter list intentionally mirrors ``__slots__``
    one-to-one (same names, same order), which lets construction be a single
    loop instead of 57 repetitive ``self.x = x`` assignments.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Every __slots__ entry has an identically-named parameter, so we can
        # assign all fields in one pass over the captured locals.
        values = locals()
        for name in self.__slots__:
            setattr(self, name, values[name])

    def args(self):
        """Return all fields as a positional tuple, in ``__slots__`` order."""
        return tuple(getattr(self, k) for k in self.__slots__)

229 

230 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length (ragged-batch) flash-attention forward with a paged KV cache.

    ``q`` is packed as [total_q_tokens, num_heads, head_size]; ``k``/``v`` are
    paged as [num_pages, block_size, num_heads_k, head_size] and addressed via
    ``page_table`` (required, asserted non-None).  Sequence boundaries come from
    ``cu_seqlens_q``/``cu_seqlens_k`` — int32, contiguous, length batch_size + 1.

    Returns the 8-tuple ``(out, q, k, v, lse, philox_args, unused, p)`` where
    ``unused`` is a placeholder scalar tensor and ``p`` only contains data when
    ``return_softmax`` is set (which requires dropout > 0).

    NOTE(review): ``gen`` is accepted but never read — RNG state is derived
    from philox_backend_seed_offset instead; confirm callers expect that.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #     paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Causality is a no-op for single-token queries (unless alibi needs it).
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa: windows covering the whole key range are disabled (-1)
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    # NOTE(review): only window_size_left is consulted here, whereas mha_fwd
    # requires both left and right to be >= 0 for is_local — confirm intended.
    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        # (B, hk, g, D) -> (B*g, hk, D): the group dimension becomes the
        # per-sequence query length, so GQA degenerates to MHA for the kernel.
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        # Unused in the ragged path: offsets come from cu_seqlens instead.
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # With softcap, the scale and cap are folded together so the kernel can
    # apply tanh-capping and softmax scaling with the same multipliers.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            # In the swapped layout the kernel writes into a scratch buffer;
            # results are copied back into the caller's `out_` afterwards.
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on, p_dropout means KEEP probability (flash-attn convention).
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # cu_seqlens_k is only trusted when seqused_k is absent.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        # Vendor quirk: iluvatar wants the page/block dims flattened for K/V.
        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        # Equivalent to params.args(): flatten fields in __slots__ order.
        args = tuple(getattr(params, k) for k in params.__slots__)

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: if avg_rows_per_sm >= 128, we are likely in prefill phase.
        # This is a rough heuristic and may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        # Undo the group/sequence swap so callers see the original layout.
        if seqlenq_ngroups_swapped:
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

        # Placeholder to keep the flash-attn-style 8-tuple return shape.
        unused = torch.empty((), dtype=torch.int64, device=q_device)
        return out, q, k, v, lse, philox_args, unused, p

571 

572 

573def mha_fwd( 

574 q, 

575 k, 

576 v, 

577 out, 

578 alibi_slopes, 

579 p_dropout, 

580 softmax_scale, 

581 is_causal, 

582 window_size_left, 

583 window_size_right, 

584 softcap, 

585 return_softmax, 

586 disable_splitkv=False, 

587): 

588 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v) 

589 q_dtype = q.dtype 

590 q_device = q.device 

591 assert q_dtype in ( 

592 torch.float16, 

593 torch.bfloat16, 

594 ), "FlashAttention only support fp16 and bf16 data type" 

595 assert q_dtype == k.dtype 

596 assert q_dtype == v.dtype 

597 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

598 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

599 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

600 batch_size, seqlen_q, num_heads, head_size = q.size() 

601 _, seqlen_k, num_heads_k, _ = k.size() 

602 

603 # Check output shape 

604 if out is not None: 

605 assert out.stride(-1) == 1 

606 assert out.dtype == q.dtype 

607 assert out.size() == (batch_size, seqlen_q, num_heads, head_size) 

608 CHECK_DEVICE(out) 

609 

610 assert ( 

611 head_size % 8 == 0 

612 ), "head_size must be a multiple of 8, this is ensured by padding!" 

613 assert ( 

614 num_heads % num_heads_k == 0 

615 ), "Number of heads in key/value must divide number of heads in query" 

616 if window_size_left >= seqlen_k: 

617 window_size_left = -1 

618 if window_size_right >= seqlen_k: 

619 window_size_right = -1 

620 if seqlen_q == 1 and alibi_slopes is None: 

621 is_causal = False 

622 if is_causal: 

623 window_size_right = 0 

624 

625 is_causal = window_size_left < 0 and window_size_right == 0 

626 is_local = window_size_left >= 0 and window_size_right >= 0 

627 

628 seqlenq_ngroups_swapped = ( 

629 seqlen_q == 1 

630 and alibi_slopes is None 

631 and num_heads > num_heads_k 

632 and window_size_left < 0 

633 and window_size_right < 0 

634 and p_dropout == 0 

635 ) 

636 q_groups = num_heads // num_heads_k 

637 

638 if seqlenq_ngroups_swapped: 

639 logger.debug("q_kg swapped.") 

640 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2) 

641 seqlen_q = q_groups 

642 num_heads = num_heads_k 

643 

644 round_multiple = lambda x, m: (x + m - 1) // m * m 

645 head_size_rounded = round_multiple(head_size, 32) 

646 seqlen_q_rounded = round_multiple(seqlen_q, 128) 

647 seqlen_k_rounded = round_multiple(seqlen_k, 32) 

648 

649 assert ( 

650 head_size <= 256 

651 ), "FlashAttention forward only supports head dimension at most 256" 

652 assert head_size == head_size_rounded, "head_size must be rounded to 32" 

653 

654 def splits_heuristic(num_tasks, num_sms, n_blocks): 

655 # splits when wave efficiency is low 

656 n_waves = triton.cdiv(num_tasks, num_sms) 

657 eff = (num_tasks / num_sms) / n_waves 

658 if eff > 0.8 or n_waves > 1: 

659 return 1 

660 

661 min_blocks_per_split = 2 

662 best_splits = min( 

663 triton.cdiv(n_blocks, min_blocks_per_split), 

664 int(math.floor(1.0 / eff)), 

665 num_sms, 

666 ) 

667 

668 return best_splits 

669 

670 with torch_device_fn.device(q_device): 

671 # Set softmax params 

672 lse = torch.empty( 

673 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device 

674 ) 

675 

676 if out is not None: 

677 if seqlenq_ngroups_swapped: 

678 out = out.reshape( 

679 batch_size, num_heads_k, q_groups, head_size 

680 ).transpose(1, 2) 

681 else: 

682 out = torch.empty_like(q, dtype=v.dtype) 

683 

684 # Set dropout params 

685 if p_dropout > 0: 

686 is_dropout = True 

687 increment = batch_size * num_heads * 32 

688 philox_seed, philox_offset = philox_backend_seed_offset(increment) 

689 philox_args = torch.tensor( 

690 [philox_seed, philox_offset], dtype=torch.int64, device=q_device 

691 ) 

692 else: 

693 is_dropout = False 

694 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device) 

695 

696 p_dropout = 1 - p_dropout 

697 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0) 

698 rp_dropout = 1.0 / p_dropout 

699 

700 if return_softmax: 

701 assert is_dropout, "Only supported with non-zero dropout." 

702 p = torch.empty( 

703 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

704 device=q_device, 

705 ) 

706 else: 

707 p = torch.empty((), device=q_device) 

708 

709 M_LOG2E = 1.4426950408889634074 

710 if softcap > 0.0: 

711 is_softcap = True 

712 adjusted_scale_softmax = softcap 

713 adjusted_softcap = softmax_scale / softcap 

714 adjusted_scale_softmax_log2e = softcap * M_LOG2E 

715 else: 

716 is_softcap = False 

717 adjusted_softcap = 0.0 

718 adjusted_scale_softmax = softmax_scale 

719 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E 

720 

721 # Set alibi params 

722 if alibi_slopes is not None: 

723 assert alibi_slopes.device == q_device 

724 assert alibi_slopes.dtype in (torch.float,) 

725 assert alibi_slopes.stride(-1) == 1 

726 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == ( 

727 batch_size, 

728 num_heads, 

729 ) 

730 alibi_slopes_batch_stride = ( 

731 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 

732 ) 

733 is_alibi = True 

734 else: 

735 alibi_slopes_batch_stride = 0 

736 is_alibi = False 

737 

738 # ONLY EVEN_K IS SUPPORTED 

739 assert head_size == head_size_rounded 

740 

741 # Do kernel dispatching 

742 def dispatch(B, H, Q, K, D, params): 

743 num_sms = torch_device_fn.get_device_properties( 

744 "cuda" 

745 ).multi_processor_count 

746 

747 # Try bh parallel 

748 # if B * H > 0.8 * num_sms: 

749 # kernel = flash_fwd_bh_parallel_kernel[(H, B)] 

750 # # Yield kernel and prefilled args 

751 # return kernel, default_args, None, None 

752 

753 # Try splitkv 

754 if not is_dropout and not is_local and not disable_splitkv: 

755 BM = block_m_splitkv_heuristic(D) 

756 n_tasks = B * H * triton.cdiv(seqlen_q, BM) 

757 BN = block_n_splitkv_heuristic(D) 

758 n_blocks = triton.cdiv(seqlen_k, BN) 

759 n_splits = splits_heuristic(n_tasks, num_sms, n_blocks) 

760 

761 if n_splits > 1: 

762 logger.debug("kernel: flash_fwd_splitkv") 

763 lse_splits = torch.empty( 

764 (n_splits, B, H, Q), dtype=torch.float, device=q_device 

765 ) 

766 out_splits = torch.empty( 

767 (n_splits, B * H, Q, D), dtype=torch.float, device=q_device 

768 ) 

769 grid = lambda args: ( 

770 triton.cdiv(Q, args["BLOCK_M"]), 

771 n_splits, 

772 B * H, 

773 ) 

774 splitkv_kernel = flash_fwd_splitkv_kernel[grid] 

775 params.o_ptr = out_splits 

776 params.softmax_lse_ptr = lse_splits 

777 params.o_row_stride = D 

778 params.o_head_stride = Q * D 

779 extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)} 

780 kernel = splitkv_kernel(*params.args(), **extra_args) 

781 

782 if D >= 128: 

783 BLOCK_M = 4 

784 elif D >= 64: 

785 BLOCK_M = 8 

786 else: 

787 BLOCK_M = 16 

788 BLOCK_K = triton.next_power_of_2(D) 

789 grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),) 

790 combine_kernel = flash_fwd_splitkv_combine_kernel[grid] 

791 combine_args = { 

792 "out_ptr": out, 

793 "lse_ptr": lse, 

794 "head_size": head_size, 

795 "out_split_stride": out_splits.stride(0), 

796 "lse_split_stride": lse_splits.stride(0), 

797 "out_b_stride": out.stride(0), 

798 "out_s_stride": head_size, 

799 "out_h_stride": out.stride(-1), 

800 "out_splits_ptr": out_splits, 

801 "lse_splits_ptr": lse_splits, 

802 "n_splits": n_splits, 

803 "BLOCK_M": BLOCK_M, 

804 "BLOCK_K": BLOCK_K, 

805 "q_total": B * H * Q, 

806 "MAX_N_SPLITS": triton.next_power_of_2(n_splits), 

807 } 

808 combine_kernel(**combine_args) 

809 return kernel 

810 

811 # Last option: flash_fwd 

812 logger.debug("kernel: flash_fwd") 

813 grid = lambda args: ( 

814 triton.cdiv(Q, args["BLOCK_M"]), 

815 H * B, 

816 ) 

817 kernel = flash_fwd_kernel[grid] 

818 kernel = kernel(*params.args()) 

819 return kernel 

820 

821 if _debug: 

822 p = torch.empty( 

823 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

824 dtype=torch.float32, 

825 device=q_device, 

826 ) 

827 return_softmax = True 

828 

829 params = fwd_params( 

830 q, # q_ptr, 

831 k, # k_ptr, 

832 v, # v_ptr, 

833 out, # o_ptr, 

834 p, # p_ptr, 

835 lse, # softmax_lse_ptr, 

836 q.stride(-3), # q_row_stride, 

837 k.stride(-3), # k_row_stride, 

838 v.stride(-3), # v_row_stride, 

839 q.stride(-2), # q_head_stride, 

840 k.stride(-2), # k_head_stride, 

841 v.stride(-2), # v_head_stride, 

842 out.stride(-3), # o_row_stride, 

843 out.stride(-2), # o_head_stride, 

844 q.stride(0), # q_batch_stride, 

845 k.stride(0), # k_batch_stride, 

846 v.stride(0), # v_batch_stride, 

847 out.stride(0), # o_batch_stride, 

848 False, # is_cu_seqlens_q, 

849 None, # cu_seqlens_q_ptr, 

850 False, # is_cu_seqlens_k, 

851 None, # cu_seqlens_k_ptr, 

852 False, # is_seqused_k, 

853 None, # seqused_k_ptr, 

854 # sizes 

855 batch_size, # b, 

856 0, # bk, 

857 num_heads, # h, 

858 num_heads_k, # hk, 

859 num_heads // num_heads_k, # h_hk_ratio, 

860 seqlen_q, # seqlen_q, 

861 seqlen_k, # seqlen_k, 

862 seqlen_q_rounded, # seqlen_q_rounded, 

863 seqlen_k_rounded, # seqlen_k_rounded, 

864 head_size, # d, 

865 head_size_rounded, # d_rounded, 

866 # scaling factors 

867 is_softcap, 

868 adjusted_softcap, # softcap, 

869 adjusted_scale_softmax, # scale_softmax, 

870 adjusted_scale_softmax_log2e, # scale_softmax_log2, 

871 # dropout 

872 is_dropout, 

873 p_dropout, 

874 rp_dropout, 

875 p_dropout_in_uint8_t, 

876 philox_args, 

877 return_softmax, 

878 # causal and swa 

879 is_causal, # is_causal, 

880 is_local, # is_local, 

881 window_size_left, # window_size_left, 

882 window_size_right, # window_size_right, 

883 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped, 

884 # alibi 

885 is_alibi, # 

886 alibi_slopes, # alibi_slopes_ptr, 

887 alibi_slopes_batch_stride, # alibi_slopes_batch_stride, 

888 # block table params 

889 0, # total_q, 

890 None, # page_table_ptr, 

891 0, # page_table_batch_stride, 

892 0, # block_size, 

893 ) 

894 

895 # Move TxD to last dims for correct stride in Triton tt.load 

896 if flag_gems.vendor_name == "iluvatar": 

897 params.q_ptr = q.transpose(1, 2) 

898 params.k_ptr = k.transpose(1, 2) 

899 params.v_ptr = v.transpose(1, 2) 

900 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params) 

901 

902 if _debug: 

903 print(f"{kernel.name} shared memory:", kernel.metadata.shared) 

904 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps) 

905 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages) 

906 # print(kernel.asm['ttgir']) 

907 

908 if seqlenq_ngroups_swapped: 

909 out = out.transpose(1, 2).reshape( 

910 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

911 ) 

912 q = q.transpose(1, 2).reshape( 

913 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

914 ) 

915 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) 

916 

917 unused = torch.empty((), dtype=torch.int64, device=q_device) 

918 

919 return out, q, k, v, lse, philox_args, unused, p