Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/attention.py: 0%

382 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6 

7# import torch.nn.functional as F 

8import triton 

9import triton.language as tl 

10 

11from flag_gems import runtime 

12from flag_gems.config import use_c_extension 

13from flag_gems.runtime import torch_device_fn 

14 

15from .flash_api import mha_fwd, mha_varlan_fwd 

16from .flash_kernel import keep 

17 

18logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

19 

20 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,  # (BLOCK_M, HEAD_DIM) fp32 running output accumulator
    l_i,  # (BLOCK_M,) running softmax denominator (base-2 domain)
    m_i,  # (BLOCK_M,) running row maximum (base-2 domain)
    query,  # (BLOCK_M, HEAD_DIM) query tile, resident for the whole loop
    K_block_ptr,  # base pointer of the first K tile to visit
    V_block_ptr,  # base pointer of the first V tile to visit
    mask_block_ptr,  # additive attention-mask tile pointer (None if unused)
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,  # query-block index of the caller
    qk_scale,  # softmax scale (natural-log domain; LOG2E applied below)
    q_load_mask,  # (BLOCK_M,) validity of query rows (offs_m < Q_CTX)
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,  # 1: off-band (causal), 2: diagonal band, 3: full (non-causal)
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,  # True when V is float8e5
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,  # load V before the QK dot (autotuned)
):
    # One pass of the online-softmax (FlashAttention-style) inner loop over
    # key/value tiles in [lo, hi); returns the updated (acc, l_i, m_i).
    # range of values handled by this stage
    if STAGE == 1:
        # Causal, strictly-below-diagonal tiles: no per-element masking needed.
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        # Causal, the diagonal tile block: element-wise causal mask applied.
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        # causal = False
    else:
        # Non-causal: sweep the whole key/value sequence.
        lo, hi = 0, KV_CTX

    # Advance tile pointers to the first column handled by this stage.
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # incase not divisible.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
        # qk = qk.to(tl.float32)

        if HAS_ATTN_MASK:
            # Additive mask tile; out-of-range entries contribute 0.
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Element-wise causal mask for the diagonal block.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                # Mask added before the LOG2E conversion to the base-2 domain.
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk *= qk_scale * LOG2E
            if HAS_ATTN_MASK:
                # NOTE(review): here attn_mask is added AFTER the LOG2E
                # scaling, while in STAGE == 2 it is added before — the two
                # paths apply the mask in different log bases. Confirm this
                # asymmetry is intended.
                qk = qk + attn_mask
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        # Online softmax update: p holds exp2 of the shifted scores.
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)  # rescale factor for previous tiles
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        # NOTE(review): this cast supersedes both branches above whenever
        # value.dtype differs from the branch dtype; the if/else only matters
        # as an intermediate rounding step.
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        # Advance tile pointers to the next BLOCK_N columns.
        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

130 

131 

# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
# we need to generate more configs.
configs = runtime.get_tuned_config("attention")

# Extra small-tile configs (BLOCK_N down to 16) so that heads with a small
# HEAD_DIM still have candidates satisfying BLOCK_N <= HEAD_DIM.
SMALL_HEAD_DIM_CONFIGS = []
for _block_m in (64, 128):
    for _block_n in (16, 32):
        for _num_stages in (2, 3, 4):
            for _num_warps in (4, 8):
                SMALL_HEAD_DIM_CONFIGS.append(
                    triton.Config(
                        {"BLOCK_M": _block_m, "BLOCK_N": _block_n, "PRE_LOAD_V": 0},
                        num_stages=_num_stages,
                        num_warps=_num_warps,
                    )
                )
configs += SMALL_HEAD_DIM_CONFIGS

145 

146 

@triton.autotune(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,  # additive mask tensor (ignored when HAS_ATTN_MASK is False)
    sm_scale,  # softmax scale (natural-log domain)
    M,  # (batch, q_head, q_ctx) fp32 output: per-row log2-sum-exp, used by backward
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,  # batch size (unused in the body)
    q_head_num,
    kv_head_num,  # unused in the body; GROUP_HEAD carries the ratio
    GROUP_HEAD: tl.constexpr,  # q_head_num // kv_head_num (GQA group size)
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,  # 3 for causal, 1 for non-causal (see wrapper)
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,  # supplied by the autotuner config
):
    # Fused attention forward. Grid: (cdiv(Q_CTX, BLOCK_M), batch * q_head, 1).
    # Each program computes one BLOCK_M slab of query rows for one (batch, head).
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: several query heads share one kv head.
    kv_head_id = head_id // GROUP_HEAD

    # int64 offsets to avoid overflow on large tensors.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K tile is addressed transposed: (HEAD_DIM, BLOCK_N).
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        # NOTE: kv_offset uses K strides; the mask has its own stride set and
        # is indexed per query head (not per kv head).
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    # qk_scale *= 1.44269504  # 1/log(2)
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue
    # M row = m_i + log2(l_i): the base-2 log-sum-exp consumed by backward.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

330 

331 

@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    # Backward preprocessing: Delta[b, h, m] = sum_d O[b, h, m, d] * DO[b, h, m, d].
    # Grid: (cdiv(Q_CTX, BLOCK_M), Z * H). Assumes O and DO are contiguous
    # (batch, head, seq, head_dim) — the wrapper asserts this. Z and H are
    # accepted but unused in the body.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = off_m < Q_CTX

    off_hz = tl.program_id(1)  # flattened batch * head index
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(
        O + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    )
    do = tl.load(
        DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    ).to(tl.float32)
    # Row-wise dot product in fp32.
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)

355 

356 

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,  # (BLOCK_N1, BLOCK_DMODEL) fp32 accumulator for dK
    dv,  # (BLOCK_N1, BLOCK_DMODEL) fp32 accumulator for dV
    Q,
    key,  # (BLOCK_N1, BLOCK_DMODEL) K tile, pre-scaled by sm_scale/ln2 (see wrapper)
    value,  # (BLOCK_N1, BLOCK_DMODEL) V tile
    sm_scale,  # unused here; dk is rescaled by the caller
    DO,  #
    M,  # per-row log2-sum-exp from forward
    D,  # per-row Delta from _attn_bwd_preprocess
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,  # apply causal masking when True
):
    # Iterates num_steps blocks of BLOCK_M1 query rows starting at start_m,
    # accumulating this kv-block's contributions to dK and dV.
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        # Q is read transposed so qkT = key @ qT lands as (BLOCK_N1, BLOCK_M1).
        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # other=inf so exp2(qkT - m) -> 0 for invalid rows.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        # Recompute softmax probabilities (transposed); scores are already in
        # the base-2 domain because key was pre-scaled by sm_scale/ln2.
        pT = tl.math.exp2(qkT - m)
        # pT = tl.math.exp2(qkT - m[None, :])

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # NOTE(review): do is loaded without a bounds mask (see the
        # commented-out masked variant). The wrapper device-asserts
        # Q_CTX % BLOCK_M1 == 0, which presumably keeps offs_m in range —
        # confirm this holds for the MASK_BLOCK_M1 (sliced) phase too.
        do = tl.load(do_ptrs)
        # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0)  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        # Zero invalid rows/cols so they cannot pollute the dk accumulation.
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv

452 

453 

# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,  # (BLOCK_M2, BLOCK_DMODEL) fp32 accumulator for dQ
    query,  # (BLOCK_M2, BLOCK_DMODEL) query tile
    K,  # K base pointer (already pre-scaled by sm_scale/ln2 in the wrapper)
    V,  #
    do,  # (BLOCK_M2, BLOCK_DMODEL) upstream gradient tile
    m,  # (BLOCK_M2, 1) per-row log2-sum-exp from forward
    D,  # per-row Delta from _attn_bwd_preprocess
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,  # apply causal masking when True
):
    # Iterates num_steps blocks of BLOCK_N2 key columns starting at start_n,
    # accumulating their contribution to this query block's dQ.
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V are read transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        # Recompute probabilities; qk is in the base-2 domain (K pre-scaled).
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            # mask = (offs_m[:, None] >= offs_n[None, :])
            # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :]
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq

517 

518 

@triton.jit
def _attn_bwd(
    Q,
    K,  # pre-scaled by sm_scale / ln(2) in the wrapper
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,  # NOTE: laid out with Q_HEAD heads; GQA reduction happens on the host
    DV,  #
    M,  # per-row log2-sum-exp from forward
    D,  # per-row Delta from _attn_bwd_preprocess
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  # unused in the body; GROUP_HEAD carries the ratio
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    # Fused attention backward. Grid: (cdiv(Q_CTX, BLOCK_N1), 1, batch * H).
    # Each program owns one BLOCK_N1 kv-block (for dK/dV) and one BLOCK_M2
    # query block (for dQ) of a single (batch, query-head) pair.
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)  # row offset into M / D (flattened b*h)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GQA: several query heads share one kv head.
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    # DK/DV use the query-head layout (adj): every query head writes its own
    # dK/dV slice; the host sums over the group afterwards.
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Phase 1: diagonal (causally masked) query blocks, in smaller slices.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK.
    # dk accumulated against the pre-scaled K; multiply by sm_scale to undo
    # the log2-domain trick (the 1/ln2 factor cancels inside dsT).
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed N_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # other=inf so exp2(qk - m) -> 0 for rows past Q_CTX.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # Undo the base-2 domain: dq was accumulated against K * sm_scale / ln2,
    # so multiplying by ln2 leaves an effective factor of sm_scale.
    dq *= LN2
    # tl.store(dq_ptrs, dq)

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

750 

751 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Functional forward entry point; delegates to the autograd Function."""
    args = (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
    return ScaleDotProductAttention.apply(*args)

772 

773 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute dq, dk, dv for scaled dot-product attention.

    Args:
        do: upstream gradient w.r.t. the output, same layout as ``o``.
        query/key/value: forward inputs, contiguous (batch, head, seq, dim).
        o: forward output.
        M: per-row log2-sum-exp saved by the forward kernel.
        attn_mask: unsupported in backward (accepted for interface symmetry).
        dropout_p: must be 0.0.
        is_causal / scale / enable_gqa: must match the forward call.

    Returns:
        (dq, dk, dv) with dk/dv reduced over GQA groups when group_head > 1.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    # The kernels index with flat (token, dim) strides, so all operands must
    # be contiguous and q/o/do must share one layout (k/v share another).
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD  # GQA: query heads per kv head

    NUM_WARPS, NUM_STAGES = 4, 1
    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
    BLK_SLICE_FACTOR = 2

    # Pre-scale K by sm_scale / ln(2) so the backward kernels can work
    # entirely in the base-2 exponent domain (matching M from forward).
    RCP_LN2 = 1.0 / math.log(2)
    arg_k = key * (sm_scale * RCP_LN2)

    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q; the GQA
    # reduction over group heads happens after the kernel.
    dq = torch.empty_like(query)
    # Allocate directly on the target device with the input dtype. The
    # previous torch.empty(...).to(device) allocated on CPU (default fp32)
    # and copied uninitialized memory to the device; it also produced fp32
    # gradients for half-precision inputs, which autograd rejects.
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K), device=key.device, dtype=key.dtype
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V), device=value.device, dtype=value.dtype
    )

    # delta[b, h, m] = sum_d o * do — needed by both dq and dk/dv kernels.
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        Q_HEAD,
        Q_CTX,  #
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,  #
    )

    grid = (triton.cdiv(Q_CTX, BLOCK_N1), 1, BATCH * Q_HEAD)
    logger.info(f"{triton.cdiv(Q_CTX, BLOCK_N1)=}")
    logger.info(f"{M.shape=}")

    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,  #
        M,
        delta,  #
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),  #
        key.stride(0),
        key.stride(1),  #
        Q_HEAD,
        Q_CTX,  #
        KV_CTX,  #
        KV_HEAD,  #
        GROUP_HEAD=group_head,  #
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,  #
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,  #
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        BLOCK_DMODEL=BLOCK_DMODEL,  #
        num_warps=NUM_WARPS,  #
        num_stages=NUM_STAGES,  #
    )

    if group_head > 1:
        # Sum per-query-head kv gradients down to the kv-head layout.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

894 

895 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton attention forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        """Run the fused attention forward kernel.

        Inputs are (batch, head, seq, head_dim) tensors. Saves q/k/v, the
        output, and the per-row log2-sum-exp M for backward.
        """
        logger.debug("GEMS SCALED DOT PRODUCT ATTENTION")
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = value.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

        o = torch.empty_like(query, dtype=value.dtype)

        # STAGE bitmask consumed by _attn_fwd: 3 = off-band + diagonal
        # (causal), 1 = single full sweep (non-causal).
        stage = 3 if is_causal else 1

        if scale is None:
            sm_scale = 1.0 / (HEAD_DIM_K**0.5)
        else:
            sm_scale = scale

        q_head_num = query.shape[1]
        kv_head_num = key.shape[1]
        assert enable_gqa or q_head_num == kv_head_num, (
            f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
            "enable_gqa must be True to support different head numbers."
        )

        # One program per (query block, batch*head); BLOCK_M is autotuned.
        grid = lambda args: (
            triton.cdiv(query.shape[2], args["BLOCK_M"]),
            query.shape[0] * query.shape[1],
            1,
        )

        if attn_mask is not None:
            HAS_ATTN_MASK = True
            if attn_mask.dtype == torch.bool:
                # NOTE(review): this maps True -> -1e6 (i.e. True = masked
                # out), the opposite of torch.sdpa's bool-mask convention
                # (True = attend) — confirm against callers.
                attn_mask = attn_mask.to(query.dtype) * -1.0e6
            stride_attn_mask_batch = attn_mask.stride(0)
            stride_attn_mask_head = attn_mask.stride(1)
            stride_attn_mask_q_seqlen = attn_mask.stride(2)
            stride_attn_mask_kv_seqlen = attn_mask.stride(3)
        else:
            HAS_ATTN_MASK = False
            # Dummy strides; the kernel never dereferences the mask pointer
            # when HAS_ATTN_MASK is False.
            stride_attn_mask_batch = 1
            stride_attn_mask_head = 1
            stride_attn_mask_q_seqlen = 1
            stride_attn_mask_kv_seqlen = 1

        # Per-row log2-sum-exp, consumed by the backward kernels.
        M = torch.empty(
            (query.shape[0], query.shape[1], query.shape[2]),
            device=query.device,
            dtype=torch.float32,
        )

        with torch_device_fn.device(query.device):
            # PRE_LOAD_V is not passed here: the autotuner supplies it.
            _attn_fwd[grid](
                query,
                key,
                value,
                attn_mask,
                sm_scale,
                M,
                o,  #
                query.stride(0),
                query.stride(1),
                query.stride(2),
                query.stride(3),  #
                key.stride(0),
                key.stride(1),
                key.stride(2),
                key.stride(3),  #
                value.stride(0),
                value.stride(1),
                value.stride(2),
                value.stride(3),  #
                stride_attn_mask_batch,
                stride_attn_mask_head,
                stride_attn_mask_q_seqlen,
                stride_attn_mask_kv_seqlen,  #
                o.stride(0),
                o.stride(1),
                o.stride(2),
                o.stride(3),  #
                query.shape[0],
                q_head_num,
                kv_head_num,  #
                q_head_num // kv_head_num,  # group_head
                query.shape[2],  #
                key.shape[2],  #
                HEAD_DIM_K,  #
                STAGE=stage,  #
                HAS_ATTN_MASK=HAS_ATTN_MASK,  #
            )

        ctx.save_for_backward(query, key, value, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = HEAD_DIM_K
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return o

    @staticmethod
    def backward(ctx, do):
        """Delegate to the Triton backward; returns grads for q, k, v only.

        The trailing Nones match the non-tensor forward arguments
        (attn_mask, dropout_p, is_causal, scale, enable_gqa).
        """
        query, key, value, o, M = ctx.saved_tensors
        is_causal = ctx.causal
        enable_gqa = ctx.enable_gqa
        sm_scale = ctx.sm_scale
        dq, dk, dv = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=sm_scale,
            enable_gqa=enable_gqa,
        )
        return dq, dk, dv, None, None, None, None, None

1029 

1030 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Public SDPA entry point backed by the Triton autograd Function."""
    return ScaleDotProductAttention.apply(
        query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
    )

1051 

1052 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward dispatching to the C/Triton mha kernels.

    Returns a tuple ``(out, lse, philox_seed, philox_offset, p)`` as produced
    by ``mha_fwd`` / ``mha_varlan_fwd``.

    Raises:
        AssertionError: for varlen inputs (unsupported) or bad head dims.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 96, 128, 192, 256}

    # Explicit None test so an explicit scale of 0.0 is not silently replaced
    # by the default (the previous `scale or ...` treated 0.0 as unset).
    if scale is not None:
        softmax_scale = scale
    else:
        softmax_scale = 1.0 / (HEAD_DIM_K**0.5)

    # -1 means "unbounded window" for the mha kernels.
    non_null_window_left = window_size_left if window_size_left is not None else -1
    non_null_window_right = window_size_right if window_size_right is not None else -1

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # NOTE: currently unreachable because of the varlen assert above;
        # kept as the intended dispatch path once varlen is supported.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            # Pass the resolved scale, consistent with the mha_fwd branch
            # (previously the raw `scale`, which may be None, was passed).
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    return (out, lse, philox_seed, philox_offset, p)

1136 

1137 

# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` as-is when it is None or already unit-strided in its last
    dimension; otherwise return a contiguous copy."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()

1141 

1142 

def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    num_splits: int = 0,
    fa_version: int = 2,
):
    """dropout_p should be set to 0.0 during evaluation

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if use_c_extension:
        # Fast path: dispatch to the pre-built C extension op, pinned to the
        # device of ``q`` so the kernel launches on the right stream/device.
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
    else:
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
        # Exactly one of cu_seqlens_k / seqused_k must describe the key lengths.
        assert (
            cu_seqlens_k is not None or seqused_k is not None
        ), "cu_seqlens_k or seqused_k must be provided"
        assert (
            cu_seqlens_k is None or seqused_k is None
        ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
        assert (
            block_table is None or seqused_k is not None
        ), "seqused_k must be provided if block_table is provided"
        # Validate the dummy FA3 knobs before doing any tensor work.
        if fa_version != 2:
            raise RuntimeError("Only FA2 is implemented.")
        if num_splits > 0:
            raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # custom op does not support non-tuple input
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
        if cu_seqlens_k is None:
            # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
            # still wants it so we pass all zeros. zeros_like (not empty_like)
            # guarantees the "all zeros" contract; empty_like would hand the
            # kernel uninitialized memory.
            cu_seqlens_k = torch.zeros_like(cu_seqlens_q)
        # Accept either plain ints or 0-dim tensors for the max lengths.
        max_seqlen_q = (
            max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
        )
        max_seqlen_k = (
            max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
        )
        out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            # Debug softmax is only meaningful when dropout is active.
            return_softmax_lse and dropout_p > 0,
            None,
        )

        return (out, softmax_lse) if return_softmax_lse else out