Coverage for src/flag_gems/runtime/backend/_cambricon/ops/attention.py: 0%

397 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6import torch.nn.functional as F 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11from flag_gems.config import use_c_extension 

12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd 

13from flag_gems.ops.flash_kernel import keep 

14from flag_gems.runtime import torch_device_fn 

15from flag_gems.utils import libentry, libtuner 

16 

# Module logger under the shared "flag_gems" root; lstrip(".") guards against a
# leading dot if __name__ is ever a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

18 

19 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,  # (BLOCK_M, HEAD_DIM) fp32 output accumulator
    l_i,  # (BLOCK_M,) running softmax denominator
    m_i,  # (BLOCK_M,) running row maximum (base-2 log domain)
    query,  # (BLOCK_M, HEAD_DIM) query tile, resident across the whole loop
    K_block_ptr,
    V_block_ptr,
    mask_block_ptr,  # additive attention-mask tile pointer (may be None)
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,
    start_m,  # index of the query tile along the sequence axis
    qk_scale,  # raw softmax scale (NOT pre-multiplied by log2(e))
    q_load_mask,  # (BLOCK_M,) validity mask for query rows
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,  # 1: strictly off-band, 2: diagonal band, 3: full range
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    KV_CTX: tl.constexpr,  # key/value sequence length
    fp8_v: tl.constexpr,  # V is float8e5: p is downcast before the PV dot
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,  # load V before the QK dot to overlap mem/compute
):
    """Online-softmax inner loop of the FlashAttention forward kernel.

    Walks the key/value sequence in BLOCK_N tiles over the range selected by
    STAGE and updates the running softmax state. All exponentials are in base
    2 (hence the LOG2E factor). Returns the updated (acc, l_i, m_i).

    Fix vs. original: the STAGE != 2 path previously added attn_mask AFTER the
    LOG2E multiply, while the STAGE == 2 path added it before — so for float
    additive masks the mask contribution differed by a log2(e) factor between
    diagonal and off-band tiles. Both paths now compute
    (qk * qk_scale + attn_mask) * LOG2E.
    """
    # Range of kv positions handled by this stage.
    if STAGE == 1:
        # Off-band: everything strictly before the current query tile.
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        # On-band: the diagonal tile, which needs causal masking.
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:
        # Non-causal: the whole key/value sequence.
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e): converts natural-log softmax to exp2

    # Loop over key/value tiles and update the accumulator.
    for start_n in range(lo, hi, BLOCK_N):
        # Invalidate columns past the end of the kv sequence.
        kv_load_mask = (start_n + offs_n) < KV_CTX
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # In case KV_CTX is not divisible by BLOCK_N.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            # Additive mask tile; out-of-range entries contribute 0.
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Causal masking on the diagonal tile.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                # Scale and add mask in the natural-log domain, then convert
                # to base 2 before applying the causal penalty.
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # Same domain ordering as the STAGE == 2 path above.
                qk = (qk * qk_scale + attn_mask) * LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i: rescale the previous state by exp2(m_i - m_ij).
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        acc = tl.dot(p, value.to(p.dtype), acc, allow_tf32=False)
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

126 

127 

# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
# we need to generate more configs.
configs = runtime.get_tuned_config("attention")
# Extra configs with BLOCK_N in {16, 32} so that head_dim 16/32 still has at
# least one legal config after the BLOCK_N <= HEAD_DIM constraint is applied.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
    )
    for BM in [64, 128]
    for BN in [16, 32]
    for s in [2, 3, 4]
    for w in [4, 8]
]
configs += SMALL_HEAD_DIM_CONFIGS

141 

142 

@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,  # additive mask; only dereferenced when HAS_ATTN_MASK
    sm_scale,  # softmax scale applied to Q @ K^T
    M,  # fp32 buffer, one log2-sum-exp per (batch, head, query row); used by backward
    Out,
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,  # batch size (not read in the body)
    q_head_num,
    kv_head_num,  # not read in the body; GQA mapping uses GROUP_HEAD
    GROUP_HEAD: tl.constexpr,  # q_head_num // kv_head_num
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,  # 3 = causal (off-band + diagonal passes), 1 = non-causal
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """FlashAttention-style forward kernel.

    Grid: (cdiv(Q_CTX, BLOCK_M), batch * q_head_num, 1). Each program computes
    one BLOCK_M tile of output rows for one (batch, query-head) pair, and
    stores the per-row log2-sum-exp into M for the backward pass.
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: GROUP_HEAD consecutive query heads read the same kv head.
    kv_head_id = head_id // GROUP_HEAD

    # int64 offsets avoid overflow on large tensors.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    # NOTE(review): V reuses kv_offset computed from K's batch/head strides;
    # this is only correct when K and V share those strides — confirm at the
    # call sites.
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l (online-softmax running state)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales (log2(e) is folded in inside _attn_fwd_inner)
    qk_scale = sm_scale
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (causal diagonal tile)
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: finalize softmax normalization; M holds log2-sum-exp per row.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

327 

328 

@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    """Compute Delta = rowsum(O * dO) per query position, in fp32.

    Grid: (cdiv(Q_CTX, BLOCK_M), batch * head). Addressing assumes O/DO are
    contiguous with row stride exactly D_HEAD (i.e. head_dim == D_HEAD).
    Z and H are accepted for signature compatibility but not read here.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = off_m < Q_CTX

    off_hz = tl.program_id(1)
    off_n = tl.arange(0, D_HEAD)
    # load one (BLOCK_M, D_HEAD) tile of O and dO
    o = tl.load(
        O + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    )
    do = tl.load(
        DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back one value per query row
    tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)

352 

353 

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  # (BLOCK_N1, BLOCK_DMODEL) fp32 accumulators
    Q,
    key,  # (BLOCK_N1, BLOCK_DMODEL) kv tile; pre-scaled by sm_scale/ln(2) at the call site
    value,  # (BLOCK_N1, BLOCK_DMODEL)
    sm_scale,
    DO,
    M,  # per-row log2-sum-exp saved by the forward pass
    D,  # per-row sum(o * do) from _attn_bwd_preprocess
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,
    MASK: tl.constexpr,  # apply the causal constraint over this row range
):
    """Accumulate dK/dV for one kv tile while sweeping query-row blocks.

    Recomputes the forward probabilities as pT = exp2(qkT - M); because K was
    pre-scaled by sm_scale/ln(2), no extra log2(e) factor is needed here.
    Returns the updated (dk, dv).
    """
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # other=+inf makes exp2(qkT - m) vanish for out-of-range rows.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        # pT: the forward softmax probabilities, transposed.
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # NOTE(review): dO is loaded without a row mask; invalid rows are
        # neutralized by zeroing pT/dsT, but the load itself may touch
        # addresses past the tensor end — confirm this is safe on the target.
        do = tl.load(do_ptrs)

        # Compute dV += P^T @ dO.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS: dS = P * (dP - Delta).
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        # Zero any contribution from out-of-range rows/columns before the dot.
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv

449 

450 

# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,  # (BLOCK_M2, BLOCK_DMODEL) fp32 accumulator
    query,
    K,  # pre-scaled by sm_scale/ln(2) at the call site
    V,
    do,
    m,  # (BLOCK_M2, 1) log2-sum-exp for this query tile
    D,  # per-row sum(o * do)
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,
    MASK: tl.constexpr,  # apply the causal constraint over this column range
):
    """Accumulate dQ for one query tile while sweeping kv-column blocks.

    The caller multiplies the returned dq by ln(2) to undo the 1/ln(2)
    pre-scaling baked into K. Returns the updated dq.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        # Recompute forward probabilities; m already folds in the denominator.
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS: dS = P * (dP - Delta).
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        # `mask` also zeroes out-of-range entries in the non-causal case.
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq

514 

515 

# Tuned configs for the backward kernel; also consulted at launch time to size
# the grid by the largest BLOCK_N1 the tuner may pick (see the backward wrapper).
config_backward = runtime.get_tuned_config("attention_bwd")

517 

518 

@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,  # pre-scaled by sm_scale/ln(2) at the call site (arg_k)
    V,
    sm_scale,
    DO,
    DQ,
    DK,
    DV,
    M,  # per-row log2-sum-exp from the forward pass
    D,  # per-row sum(o * do) from _attn_bwd_preprocess
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,
    kv_stride_z,
    kv_stride_h,
    H,  # query head num
    Q_CTX,
    KV_CTX,
    kv_head_num,  # not read in the body; GQA mapping uses GROUP_HEAD
    GROUP_HEAD: tl.constexpr,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """Fused attention backward kernel (causal).

    Each program computes dK/dV for one BLOCK_N1 kv tile and dQ for one
    BLOCK_M2 query tile of a single (batch, query-head) pair. DK/DV are
    addressed with the query-head layout (`adj`); the wrapper allocates them
    with Q_HEAD heads and reduces GQA groups afterwards.
    """
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GQA: GROUP_HEAD consecutive query heads share one kv head.
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    # Causal: query rows before the diagonal contribute nothing to this kv tile.
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    # Stage 1: the diagonal region, in finer MASK_BLOCK_M1 steps with causal
    # masking enabled.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,
        Q,
        key,
        value,
        sm_scale,
        DO,
        M,
        D,
        stride_tok,
        stride_d,
        H,
        Q_CTX,
        KV_CTX,
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,
        start_n,
        start_m,
        num_steps,
        MASK=True,
    )

    # Compute dK and dV for non-masked blocks (rows strictly past the diagonal).
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(
            dk,
            dv,
            Q,
            key,
            value,
            sm_scale,
            DO,
            M,
            D,
            stride_tok,
            stride_d,
            H,
            Q_CTX,
            KV_CTX,
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,
            start_n,
            start_m,
            num_steps,
            MASK=False,
        )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK. (dV needs no scale: pT are already-normalized probabilities.)
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed N_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # other=+inf makes exp2(qk - m) vanish for out-of-range rows.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,
            do,
            m,
            D,
            stride_tok,
            stride_d,
            H,
            Q_CTX,
            KV_CTX,
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,
            start_m,
            start_n,
            num_steps,
            MASK=True,
        )

    # Stage 2 - non-masked blocks: kv columns strictly left of the diagonal.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    # NOTE(review): when start_n is not a multiple of BLOCK_N2 the start index
    # below is negative, producing negative kv offsets that still pass the
    # offs_n < KV_CTX load mask — confirm tile sizes guarantee divisibility.
    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,
            do,
            m,
            D,
            stride_tok,
            stride_d,
            H,
            Q_CTX,
            KV_CTX,
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,
            MASK=False,
        )
    # Write back dQ; de-scale by ln(2) because K was pre-multiplied by 1/ln(2).
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

755 

756 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton SDPA forward kernel.

    Args:
        query/key/value: (batch, head, seqlen, head_dim) tensors.
        attn_mask: optional (batch, head, q_seqlen, kv_seqlen) mask. Boolean
            masks follow torch SDPA semantics (True = take part in attention);
            float masks are added to the scores.
        dropout_p: must be 0.0 (unsupported).
        is_causal: apply causal masking.
        scale: softmax scale; defaults to 1/sqrt(head_dim).
        enable_gqa: allow q_head_num != kv_head_num (grouped-query attention).

    Returns:
        (o, M): the attention output and the per-row log2-sum-exp tensor
        needed by the backward pass.

    Fix vs. original: boolean masks were converted with
    `attn_mask.to(dtype) * -1.0e6`, which penalizes the True positions. Torch
    semantics say True marks positions that SHOULD attend, so the bias now
    goes to the False positions instead.
    """
    logger.debug("GEMS_CAMBRICON SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # STAGE=3: causal (off-band pass + diagonal pass); STAGE=1: full range.
    stage = 3 if is_causal else 1

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    # One program per query tile per (batch, query head); BLOCK_M is chosen by
    # the autotuner, hence the deferred grid.
    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # torch SDPA semantics: True = participate. Bias the masked-out
            # (False) positions with a large negative value.
            attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill_(
                ~attn_mask, -1.0e6
            )
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        # Dummy strides; the kernel never dereferences the mask pointer.
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    # Per-row log2-sum-exp, consumed by the backward pass.
    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            query.shape[0],
            q_head_num,
            kv_head_num,
            q_head_num // kv_head_num,  # GROUP_HEAD
            query.shape[2],  # Q_CTX
            key.shape[2],  # KV_CTX
            HEAD_DIM_K,
            STAGE=stage,
            HAS_ATTN_MASK=HAS_ATTN_MASK,
        )
    return o, M

859 

860 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute (dq, dk, dv) for the Triton SDPA forward pass.

    M is the per-row log2-sum-exp saved by the forward kernel. All inputs must
    be contiguous; query/o/do must share strides, as must key/value. The
    attn_mask argument is accepted but not used by the backward kernels.
    """
    logger.debug("GEMS_CAMBRICON SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2

    RCP_LN2 = 1.0 / math.log(2)  # = 1.4426950408889634

    # Pre-scale K so the backward kernels can use exp2 directly:
    # exp(sm_scale * q.k) == exp2(q . (k * sm_scale * RCP_LN2)).
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256

    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    # delta[b, h, m] = sum(o * do) per query row, computed by the preprocess kernel.
    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q;
    # per-group gradients are summed below when group_head > 1.
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    _attn_bwd_preprocess[pre_grid](
        o,
        do,
        delta,
        BATCH,
        Q_HEAD,
        Q_CTX,
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,
    )

    # Size the grid for the largest BLOCK_N1 any tuner config may choose.
    # NOTE(review): the grid is sized by Q_CTX while the kernel's pid indexes
    # kv tiles, and a config with BLOCK_N1 < max_block_n1 would leave kv tiles
    # uncovered — confirm the tuned configs make this safe.
    max_block_n1 = (
        max([cfg.kwargs["BLOCK_N1"] for cfg in config_backward])
        if config_backward
        else 128
    )
    grid = (triton.cdiv(Q_CTX, max_block_n1), 1, BATCH * Q_HEAD)

    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,
        M,
        delta,
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),
        key.stride(0),
        key.stride(1),
        Q_HEAD,
        Q_CTX,
        KV_CTX,
        KV_HEAD,
        GROUP_HEAD=group_head,
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
        BLOCK_DMODEL=BLOCK_DMODEL,
    )

    # GQA: fold the per-query-head kv gradients back onto the kv heads.
    if group_head > 1:
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

995 

996 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd binding for the Triton SDPA forward/backward kernels.

    NOTE(review): attn_mask is used by forward but is neither saved for nor
    passed to backward (backward receives attn_mask=None), so gradients
    computed through a mask ignore it — confirm whether masked training is
    meant to be supported here.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Resolve the default scale here so forward and backward use the same value.
        sm_scale = scale if scale is not None else 1.0 / (key.shape[-1] ** 0.5)
        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        # M is the per-row log2-sum-exp required by the backward kernels.
        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return o

    @staticmethod
    def backward(ctx, do):
        query, key, value, o, M = ctx.saved_tensors
        is_causal = ctx.causal
        enable_gqa = ctx.enable_gqa
        sm_scale = ctx.sm_scale
        dq, dk, dv = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=sm_scale,
            enable_gqa=enable_gqa,
        )
        # One gradient slot per forward input after ctx: query, key, value,
        # attn_mask, dropout_p, is_causal, scale, enable_gqa.
        return dq, dk, dv, None, None, None, None, None

1048 

1049 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable scaled dot-product attention.

    Thin wrapper that routes all arguments through the autograd Function;
    torch.autograd.Function.apply only accepts positional arguments, so they
    are forwarded as a tuple in declaration order.
    """
    forward_args = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forward_args)

1070 

1071 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward entry point (dense path only).

    Pads head_dim up to the next supported size when needed, dispatches to
    mha_fwd, and slices the padding off the output. Returns
    (out, lse, philox_seed, philox_offset, p).
    """
    logger.debug("GEMS_CAMBRICON FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Zero-pad q/k/v up to the next supported head_dim; the pad is sliced
        # off the output below.
        padded_head_dim = None
        for d in supported_head_dims:
            if d >= HEAD_DIM_K:
                padded_head_dim = d
                break
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Default scale uses the ORIGINAL head_dim, so padding does not change it.
    # NOTE(review): `scale or ...` treats an explicit scale of 0.0 as unset.
    softmax_scale = scale or 1.0 / (original_head_dim**0.5)
    if window_size_left is not None:
        non_null_window_left = window_size_left
    else:
        non_null_window_left = -1
    if window_size_right is not None:
        non_null_window_right = window_size_right
    else:
        non_null_window_right = -1

    out = torch.empty_like(query)
    # NOTE(review): this varlen branch is unreachable — the assert above
    # requires cumulative_sequence_length_q to be None. It also forwards the
    # raw `scale` (possibly None) rather than softmax_scale; revisit both
    # before enabling varlen.
    if cumulative_sequence_length_q is not None:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    # Drop the padded tail so callers see the original head_dim.
    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)

1172 

1173 

1174# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py 

def maybe_contiguous(x):
    """Return *x* untouched when it is None or already has unit stride in its
    last dimension; otherwise return a contiguous copy of it."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()

1177 

1178 

def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # This port only implements FA2 semantics; fail fast on FA3-only knobs.
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS_CAMBRICON.")
    if use_c_extension:
        # Fast path: forward everything verbatim to the compiled extension op.
        logger.debug("GEMS_CAMBRICON FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
    else:
        logger.debug("GEMS_CAMBRICON FLASH_ATTN_VARLEN_FUNC")
        # Exactly one of cu_seqlens_k / seqused_k drives the key lengths.
        assert (
            cu_seqlens_k is not None or seqused_k is not None
        ), "cu_seqlens_k or seqused_k must be provided"
        assert (
            cu_seqlens_k is None or seqused_k is None
        ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
        assert (
            block_table is None or seqused_k is not None
        ), "seqused_k must be provided if block_table is provided"
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # custom op does not support non-tuple input
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        # The kernel requires unit stride in the last dim; copy only if needed.
        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
        # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
        # still wants it so we pass all zeros.
        # FIX: previously this was torch.empty_like (uninitialized memory),
        # contradicting the "all zeros" contract above; it was also allocated
        # even when a real cu_seqlens_k was supplied.
        if cu_seqlens_k is None:
            cu_seqlens_k_arg = torch.zeros_like(cu_seqlens_q)
        else:
            cu_seqlens_k_arg = cu_seqlens_k
        # The forward expects plain Python ints; unwrap 0-dim tensors.
        max_seqlen_q = (
            max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
        )
        max_seqlen_k = (
            max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
        )
        out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_k_arg,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            # only materialize the debug probability tensor under dropout
            return_softmax_lse and dropout_p > 0,
            None,
        )

        return (out, softmax_lse) if return_softmax_lse else out