Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/flash_kernel.py: 0%

534 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import triton 

2import triton.language as tl 

3 

4# from flag_gems import runtime 

5from flag_gems.utils import libentry, tl_extra_shim 

6 

7 

@triton.jit
def u64_to_lohi(x):
    """Split a uint64 value into its (high, low) 32-bit halves as uint32."""
    hi = (x >> 32).to(tl.uint32)
    lo = (x & 0xFFFFFFFF).to(tl.uint32)
    return hi, lo

11 

12 

@triton.jit
def u64_from_lohi(lo, hi):
    """Recombine (lo, hi) 32-bit halves into a single uint64 value.

    Fix: in Python, `+` binds tighter than `<<`, so the previous expression
    `hi << 32 + lo` computed `hi << (32 + lo)`. The shift is now
    parenthesized so the high half lands in the upper 32 bits.
    """
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)

16 

17 

@triton.jit
def philox_(seed, subsequence, offset):
    """Philox4x32 counter-based RNG.

    Derives four uint32 random lanes from the 64-bit key `seed` and the two
    64-bit counters (`subsequence`, `offset`). Constants are the standard
    Philox key-schedule and multiplier constants.
    """
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    # Split the 64-bit key and counters into 32-bit lanes.
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # Six unrolled rounds with a key bump after each, plus one final round
    # below without a key bump.
    # NOTE(review): reference Philox4x32-10 runs 10 rounds; 7 are performed
    # here — presumably matching the companion dropout implementation; confirm.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    for _ in tl.static_range(6):
        res0 = kPhiloxSA * c0.to(tl.uint64)
        res1 = kPhiloxSB * c2.to(tl.uint64)
        res0_x, res0_y = u64_to_lohi(res0)
        res1_x, res1_y = u64_to_lohi(res1)
        # One Philox round: multiply, swap lanes, XOR in the key.
        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round (no key bump afterwards).
    res0 = kPhiloxSA * c0.to(tl.uint64)
    res1 = kPhiloxSB * c2.to(tl.uint64)
    res0_x, res0_y = u64_to_lohi(res0)
    res1_x, res1_y = u64_to_lohi(res1)
    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x

    return c0, c1, c2, c3

45 

46 

@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Drop entries of P selected by `mask`.

    When `encode_dropout_in_sign_bit` is set, dropped entries are negated so
    the dropout pattern survives in the sign bit; otherwise they are zeroed.
    """
    if not encode_dropout_in_sign_bit:
        P = tl.where(mask, (P * 0).to(P.dtype), P)
    else:
        P = tl.where(mask, -P, P)
    return P

58 

59 

@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply Philox-based dropout to a (BLOCK_M, BLOCK_N) probability tile P.

    The RNG is seeded per (batch, head) via `philox_offset` and per element
    via a (row, col)-derived subsequence, so the pattern is reproducible in
    the backward pass. No-op when `is_dropout` is False.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        row = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Down scale col_idx by 4: one philox_ call yields 4 uint32 lanes,
        # so a single RNG invocation covers 4 consecutive columns.
        col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Broadcast offset to subsequence's shape without changing its value.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        # Interleave the 4 random lanes back into a full BLOCK_N-wide tile.
        r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)

        # Keep an element iff its low random byte clears the keep threshold.
        mask = (r & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P

98 

99 

@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi positional bias to a tile of attention scores S.

    `col_idx`/`row_idx` locate the tile inside the full attention map;
    `alibi_slope` is the per-head slope (already divided by the softmax
    scale by the caller). No-op when `is_alibi` is False.
    """
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant,
            # i.e., softmax(A + bias + const) = softmax(A + bias). The
            # following two biases are no different if causal is true.
            # bias_1 = [
            #    -4, -3, -2,  X,  X,
            #    -4, -3, -2, -1,  X,
            #    -4, -3, -2, -1,  0,
            # ]
            # bias_2 = [
            #    -2, -1,  0,  X,  X,
            #    -3, -2, -1,  0,  X,
            #    -4, -3, -2, -1,  0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += bias
        else:
            # Non-causal: standard symmetric ALiBi, distance measured from the
            # (bottom-right aligned) diagonal.
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += bias

    return S

136 

137 

@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Mask invalid attention scores in tile S with -inf.

    Handles the causal mask, the sliding-window (local) mask, and — when the
    sequence lengths are not multiples of the block sizes — the out-of-bounds
    column mask. Returns S unchanged when no masking is required.
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    if need_mask:
        # Extra care should be taken to avoid one-off errors: both col_lb and
        # col_rb are inclusive! The diagonal is bottom-right aligned, hence
        # the (max_seqlen_k - max_seqlen_q) shift.
        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        col_rb = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            # Sliding window: mask both sides of the window.
            S = tl.where(
                (col_idx[None, :] > col_rb[:, None])
                | (col_idx[None, :] < col_lb[:, None]),
                float("-inf"),
                S,
            )

        if (not is_local) & (not is_causal) & (not is_even_mn):
            # Only out-of-bounds columns need masking in this case.
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S

175 

176 

@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
):
    """One online-softmax update step over a new score tile S.

    Maintains the running row maximum and row sum and rescales the output
    accumulator so earlier contributions stay consistent with the new max.
    Returns (O_acc, P, row_max, row_sum) where P holds the un-normalized
    probabilities exp2(S*scale - max*scale) for this tile.
    """
    prev_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    # On border tiles a fully-masked row keeps max == -inf; substitute 0 so
    # the exp2 below does not turn (-inf) - (-inf) into NaN.
    if is_border:
        cur_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        cur_max = row_max

    # Rescale previous accumulator/sum by exp2((old_max - new_max) * scale).
    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
    row_sum *= p_scale
    O_acc *= p_scale[:, None]

    max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum

203 

204 

@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Optionally squash attention scores with tanh soft-capping."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)
    return S

211 

212 

def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-kv kernel based on the head dimension."""
    if headdim <= 128:
        return 128
    return 64

215 

216 

def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-kv kernel based on the head dimension."""
    if headdim <= 64:
        return 64
    return 32

219 

220 

def is_even_mn(args):
    """Autotune heuristic: True when seqlens and windows divide the blocks.

    Checks, in order: both sequence lengths are multiples of their block
    sizes, one sequence length divides the other, and each active window
    size (-1 means disabled) is a multiple of BN. Short-circuits left to
    right, exactly like the original nested-if formulation.
    """
    return (
        args["M"] % args["BM"] == 0
        and args["N"] % args["BN"] == 0
        and (args["M"] % args["N"] == 0 or args["N"] % args["M"] == 0)
        and (args["WL"] == -1 or args["WL"] % args["BN"] == 0)
        and (args["WR"] == -1 or args["WR"] % args["BN"] == 0)
    )

229 

230 

def block_m_splitkv_heuristic_spec_args(args):
    """BLOCK_M heuristic reading the head dim from the kernel arg dict."""
    if args["d"] <= 128:
        return 128
    return 64

233 

234 

def block_n_splitkv_heuristic_spec_args(args):
    """BLOCK_N heuristic reading the head dim from the kernel arg dict."""
    if args["d"] <= 64:
        return 64
    return 32

237 

238 

def is_even_mn_spec_args(args):
    """Even-MN heuristic keyed on the full kernel argument names.

    Same predicate as `is_even_mn` but reading `seqlen_q`/`seqlen_k`,
    `BLOCK_M`/`BLOCK_N` and `window_size_left`/`window_size_right`.
    Evaluation order matches the original nested ifs (short-circuit).
    """
    sq = args["seqlen_q"]
    sk = args["seqlen_k"]
    bn = args["BLOCK_N"]
    return (
        sq % args["BLOCK_M"] == 0
        and sk % bn == 0
        and (sq % sk == 0 or sk % sq == 0)
        and (args["window_size_left"] == -1 or args["window_size_left"] % bn == 0)
        and (args["window_size_right"] == -1 or args["window_size_right"] % bn == 0)
    )

257 

258 

def keep(cfg, must_keep=None):
    """Autotune-config filter for the forward kernel.

    Keeps a config when its (BLOCK_M, BLOCK_N, num_warps) triple is one of
    the known-good shapes, or when it appears in `must_keep`.

    Fix: the original returned `None` (from `must_keep and cfg in must_keep`
    with `must_keep=None`) instead of `False`; the predicate now always
    returns a bool. Truthiness — and thus `filter()` behavior — is unchanged.

    Args:
        cfg: a triton.Config-like object with `.kwargs` and `.num_warps`.
        must_keep: optional collection of configs that are always retained.

    Returns:
        bool: True when the config should be kept.
    """
    shape = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    if shape in ((128, 32, 4), (128, 128, 8)):
        return True
    # Configurations explicitly listed in `must_keep` are always retained.
    return bool(must_keep) and cfg in must_keep

268 

269 

def prune_fwd_configs(configs, nargs, **kwargs):
    """Prune autotune configs for the forward kernel.

    With dropout enabled only configs using 4 warps and fewer than 4 stages
    survive; otherwise the config list is returned untouched.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]

278 

279 

280# @libentry() 

281# @triton.autotune( 

282# configs=list(filter(keep, runtime.get_tuned_config("attention"))), 

283# prune_configs_by={"early_config_prune": prune_fwd_configs}, 

284# key=["d", "is_dropout"], 

285# ) 

286# @triton.heuristics( 

287# values={ 

288# "BLOCK_K": lambda args: triton.next_power_of_2(args["d"]), 

289# "PRE_LOAD_V": lambda args: False, 

290# "IS_EVEN_MN": is_even_mn, 

291# } 

292# ) 

293# @triton.jit( 

294# do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"] 

295# ) 

def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward: one CTA produces one (BLOCK_M, d) output tile
    for a single (batch, head) pair, iterating over key/value column blocks.

    Fixes vs. the previous revision (Python operator precedence):
    - `IS_EVEN_MN & d == BLOCK_K` parsed as `(IS_EVEN_MN & d) == BLOCK_K`,
      which is almost always False and silently disabled the unmasked
      fast-path loads/stores. Now written `IS_EVEN_MN & (d == BLOCK_K)`.
    - `rowsum_ == 0 | (rowsum_ != rowsum_)` parsed as
      `rowsum_ == (0 | ...)`, breaking the empty-row/NaN guard for the LSE.
      Now written `(rowsum_ == 0) | (rowsum_ != rowsum_)`.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is
    # assigned to process. The frame edges are rounded to multiples of
    # BLOCK_M and BLOCK_N for rows and columns respectively.
    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that need explicit masking.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Load the Q tile once; it is reused for every KV block.
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    if IS_EVEN_MN & (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    # K/V share the same batch/head strides (MQA/GQA: h_hk_ratio query heads
    # per KV head).
    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    # --- Phase 1: masked tail blocks, walked right-to-left from col_max. ---
    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Store the sign-encoded dropped probabilities for the
                    # backward pass / debugging.
                    P_drop = P
                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # --- Phase 2: full (unmasked) blocks, walked left-to-right. ---
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(p_bk0 + off, mask=dmask[:, None], cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        # Only the local (sliding-window) mask can touch interior blocks.
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum)
    # cancels the effect of rowmax and outputs lse only. Empty/NaN rows map
    # to +inf.
    lse = tl.where(
        (rowsum_ == 0) | (rowsum_ != rowsum_),
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Normalize (and un-scale dropout, which kept values at 1/(1-p) scale).
    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)

729 

730 

@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    # TODO: batch*head-parallel variant of the forward kernel; not implemented.
    pass

735 

736 

737# @libentry() 

738# @triton.heuristics( 

739# values={ 

740# "BLOCK_M": block_m_splitkv_heuristic_spec_args, 

741# "BLOCK_N": block_n_splitkv_heuristic_spec_args, 

742# "BLOCK_K": lambda args: triton.next_power_of_2(args["d"]), 

743# "num_warps": lambda args: 4, 

744# "num_stages": lambda args: 3, 

745# "PRE_LOAD_V": lambda args: True, 

746# "IS_EVEN_MN": is_even_mn_spec_args, 

747# } 

748# ) 

749# @triton.jit( 

750# do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"] 

751# ) 

def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward: each program handles one
    (m_block, kv-split, batch*head) slice and writes a partial output and
    LSE; `flash_fwd_splitkv_combine_kernel` merges the splits.

    Fixes vs. the previous revision (Python operator precedence):
    `IS_EVEN_MN & BLOCK_K == d` / `IS_EVEN_MN & d == BLOCK_K` parsed as
    `(IS_EVEN_MN & BLOCK_K) == d`, and `rowsum_ == 0 | (...)` parsed as
    `rowsum_ == (0 | ...)`; the comparisons are now parenthesized.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    # KV block range owned by this split.
    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First KV block index (from the top) that requires masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    # Load the Q tile once; it is reused for every KV block.
    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    if IS_EVEN_MN & (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    # MQA/GQA: h_hk_ratio query heads share one KV head.
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # Masked path: this split overlaps the causal/boundary region.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE: -inf marks rows with no contribution so the combine kernel can
    # ignore them.
    lse = tl.where(
        (rowsum_ == 0) | (rowsum_ != rowsum_),
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets; seq_block offsets are already in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )

1076 

1077 

@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge the per-split partial outputs/LSEs written by the split-kv
    kernel into the final attention output and log-sum-exp.

    Each program combines BLOCK_M query rows: it reduces the split LSEs in a
    numerically stable way and takes the correspondingly weighted average of
    the split outputs.
    """
    pid = tl.program_id(0)
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical
    # stability.
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    # Inactive splits load as -inf and thus contribute exp(-inf) = 0 below.
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to
    # a scaled factor exp(-max_lse).
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    # Load all split outputs (BLOCK_M, MAX_N_SPLITS, BLOCK_K) and take their
    # Zi/Z-weighted average over the split axis.
    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])

1146 

1147 

@triton.jit
def virtual_to_cache(virtual_index, page_table_ptr, block_size):
    # Translate a logical kv-sequence position of the current batch element
    # into a physical row index of the paged kv-cache. `page_table_ptr` must
    # already point at this batch element's block table; `block_size` is the
    # number of kv rows held by one page.
    logical_page = virtual_index // block_size
    within_page = virtual_index % block_size
    physical_page = tl.load(page_table_ptr + logical_page).to(tl.int32)
    return physical_page * block_size + within_page

1157 

1158 

@triton.jit
def load_from_kvcache(
    i,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d,
    k_row_stride,
    BLOCK_K: tl.constexpr,
):
    # Gather one K tile and one V tile from the paged kv-cache for the kv
    # positions in `i`. K is produced transposed (BLOCK_K x N, feature dim
    # first) ready for Q @ K; V is produced row-major (N x BLOCK_K).
    phys_rows = virtual_to_cache(i, page_table_ptr, block_size)
    feat = tl.arange(0, BLOCK_K)
    feat_as_col = feat[:, None]
    feat_as_row = feat[None, :]
    # NOTE(review): V rows are addressed with k_row_stride — assumes K and V
    # share the same row stride; confirm against the caller.
    bK = tl.load(
        k_ptr_base + (feat_as_col + phys_rows[None, :] * k_row_stride),
        mask=feat_as_col < d,
        other=0.0,
    )
    bV = tl.load(
        v_ptr_base + (feat_as_row + phys_rows[:, None] * k_row_stride),
        mask=feat_as_row < d,
        other=0.0,
    )
    return bK, bV

1180 

1181 

@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Variable-length (packed) flash-attention forward over a paged kv-cache.
    # Grid: (query blocks, batch, query heads); each program produces one
    # BLOCK_M x d output tile plus the matching lse entries.
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)

    # Locate this request's query rows inside the packed Q / O / lse buffers.
    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # Effective kv length: from cumulative lengths when given, optionally
    # overridden by an explicit per-batch seqused_k.
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
    else:
        k_len_cache = seqlen_k

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: this row block starts at or past the end of the query.
    # Fix: was `>`, which let the boundary program (start == q_len) run a
    # full, entirely-masked pass for nothing.
    if m_block * BLOCK_M >= q_len:
        return

    is_even_mn: tl.constexpr = False

    # [n_block_min, n_block_max) = kv blocks this program must visit after
    # clipping with the causal / sliding-window bounds.
    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    page_table_ptr += bid * page_table_batch_stride
    # Shift q/k/v to the current head; kv heads may be shared across query
    # heads (grouped-query attention) via h_hk_ratio.
    q_row_offset = hid * q_head_stride
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    # Online-softmax running state: output accumulator, row max, row sum.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide so the slope survives the later softmax scaling.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Number of trailing kv blocks that need explicit masking (causal
    # diagonal / ragged tail); earlier blocks are fully valid.
    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    n_block = n_block_max - 1
    # Phase 1: masked kv blocks, walking backwards from the last block.
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # Fix: row/col starts were swapped here relative to the identical
            # call in the unmasked loop below; apply_dropout takes row_start
            # (query dim, advances with m_block) before col_start (kv dim).
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Phase 2: fully-unmasked kv blocks (no causal masking needed).
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE. Fix: original wrote `rowsum_ == 0 | (rowsum_ != rowsum_)`, which
    # Python parses as `rowsum_ == (0 | (rowsum_ != rowsum_))` since `|`
    # binds tighter than `==` — so NaN row sums were never caught by the
    # guard (NaN == 1 is False) and produced NaN lse/output.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    # Final softmax normalization of the accumulated output.
    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )