Coverage for src/flag_gems/runtime/backend/_cambricon/ops/attention.py: 0%
397 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd
13from flag_gems.ops.flash_kernel import keep
14from flag_gems.runtime import torch_device_fn
15from flag_gems.utils import libentry, libtuner
17logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  # query tile held in SRAM for the whole loop
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """Online-softmax inner loop of the fused attention forward kernel.

    Iterates over KV blocks in [lo, hi) (range chosen by STAGE), updating the
    running accumulator ``acc``, row sum ``l_i`` and row max ``m_i``.
    STAGE == 1: off-band (fully unmasked) region of a causal pass.
    STAGE == 2: the diagonal band, where the causal mask applies per element.
    STAGE == 3 (else-branch): non-causal, scan the whole KV context.
    Softmax is computed in base 2 (exp2 / LOG2E) for speed.
    Returns the updated (acc, l_i, m_i).
    """
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        # causal = False
    else:
        lo, hi = 0, KV_CTX
    # advance block pointers to the start of this stage's range
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen
    LOG2E = 1.44269504  # log2(e) constant
    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        qk = tl.dot(query, key, allow_tf32=False)
        # in case KV_CTX is not divisible by BLOCK_N: invalidate padded columns
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
        # qk = qk.to(tl.float32)
        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )
        if STAGE == 2:
            # diagonal band: apply per-element causal mask
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            if HAS_ATTN_MASK:
                # NOTE(review): here attn_mask is multiplied by LOG2E (via the
                # subsequent `qk *= LOG2E`), while in the STAGE != 2 branch it
                # is added AFTER the LOG2E scaling — verify this asymmetry is
                # intended.
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk *= qk_scale * LOG2E
            if HAS_ATTN_MASK:
                qk = qk + attn_mask
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i: rescale by exp2 of the change in the row max
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        acc = tl.dot(p, value.to(p.dtype), acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij
        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen
        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen
    return acc, l_i, m_i
# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
# we need to generate more configs.
configs = runtime.get_tuned_config("attention")
# Extra autotuning candidates with small BLOCK_N so that small head dims
# still have at least one config satisfying BLOCK_N <= HEAD_DIM.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for block_m in (64, 128)
    for block_n in (16, 32)
    for num_stages in (2, 3, 4)
    for num_warps in (4, 8)
]
configs += SMALL_HEAD_DIM_CONFIGS
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """Fused attention forward kernel (flash-attention style).

    Grid: axis 0 tiles the query sequence by BLOCK_M; axis 1 enumerates
    (batch, q_head) pairs.  GQA is handled by mapping each query head to
    kv_head_id = head_id // GROUP_HEAD.  Besides the output ``Out``, the
    kernel stores per-row log2-sum-exp statistics into ``M`` (needed by the
    backward pass).  STAGE == 3 means causal (off-band pass then diagonal
    band), STAGE == 1 means non-causal (single full pass).
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    kv_head_id = head_id // GROUP_HEAD
    # base offsets for this (batch, head); int64 to avoid overflow on large tensors
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )
    offs_headsize = tl.arange(0, HEAD_DIM)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)
    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K tile is addressed transposed: (HEAD_DIM, BLOCK_N)
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )
    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None
    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )
    # initialize pointer to m and l (l_i starts at 1.0 so log2(l_i) is finite
    # for fully-masked rows)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    # qk_scale *= 1.44269504 # 1/log(2)
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (the diagonal blocks of a causal pass)
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: M holds log2-sum-exp per query row; acc is normalized by l_i
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    """Backward preprocessing: Delta[row] = sum_d(O[row, d] * DO[row, d]).

    Grid axis 0 tiles the query rows by BLOCK_M; axis 1 enumerates flattened
    (batch * head).  The pointer arithmetic assumes O/DO are contiguous with
    layout (batch*head, Q_CTX, D_HEAD).  Z and H are currently unused.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = off_m < Q_CTX
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, D_HEAD)
    # load the output tile and its incoming gradient
    o = tl.load(
        O + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    )
    do = tl.load(
        DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back: Delta is (batch*head, Q_CTX)
    tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dK/dV for one KV tile over ``num_steps`` query blocks.

    ``key`` was pre-scaled by sm_scale / ln(2) on the host (see the wrapper's
    ``arg_k``), so qkT here is already in log2 space and exp2(qkT - m)
    reproduces the forward softmax probabilities.  ``M`` holds the forward
    log2-sum-exp, ``D`` the precomputed sum(O * dO) per row.  With MASK=True
    the causal constraint offs_m >= offs_n is applied on top of the bounds
    masks.  Returns the updated (dk, dv) accumulators.
    """
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )
    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )
        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)
        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)
        # Load m before computing qk to reduce pipeline stall.
        # other=inf makes exp2(qkT - m) == 0 for out-of-range rows.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )
        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)
        # pT = tl.math.exp2(qkT - m[None, :])
        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)
        # NOTE(review): this load is unmasked; out-of-range rows contribute 0
        # through pT/dsT, but the load itself may touch memory past Q_CTX —
        # confirm this is safe on this backend.
        do = tl.load(do_ptrs)
        # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0) # (BLOCK_M1, BLOCK_DMODEL)
        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )
        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dQ for one query tile over ``num_steps`` KV blocks.

    ``K`` was pre-scaled by sm_scale / ln(2) on the host, so exp2(qk - m)
    reproduces the forward probabilities (``m`` is the forward log2-sum-exp,
    already broadcast to (BLOCK_M2, 1) by the caller).  ``D`` holds the
    per-row sum(O * dO).  MASK=True additionally applies the causal
    constraint.  Returns the updated dq accumulator; the caller rescales the
    final dq by ln(2) to undo the pre-scaling of K.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX
    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX
        # K and V tiles are addressed transposed: (BLOCK_DMODEL, BLOCK_N2)
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        # bounds mask; causal constraint is folded in below when MASK is set
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            # mask = (offs_m[:, None] >= offs_n[None, :])
            # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :]
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq
# Autotuning configs for the backward kernel, loaded from the runtime's
# tuned-config table.
config_backward = runtime.get_tuned_config("attention_bwd")
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    """Attention backward kernel: computes dK/dV for one KV tile and dQ for
    one query tile per program.

    Grid axis 0 tiles the sequence; axis 2 enumerates (batch, q_head).
    ``K`` was pre-scaled by sm_scale / ln(2) on the host.  Note DK/DV are
    offset with the *query*-layout ``adj`` — the wrapper allocates dk/dv with
    Q_HEAD heads and reduces over the GQA group afterwards.  The diagonal
    (causally masked) region is processed with finer blocks
    (BLOCK / BLK_SLICE_FACTOR) than the fully unmasked region.
    """
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)
    pid = tl.program_id(0)
    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz
    # load scales
    offs_k = tl.arange(0, BLOCK_DMODEL)
    start_n = pid * BLOCK_N1
    start_m = start_n
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX
    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    # load K and V: they stay in SRAM throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    # Phase 1a: masked (diagonal) blocks, at finer granularity MASK_BLOCK_M1.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )
    # Compute dK and dV for non-masked blocks.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1
    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)
    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])
    # Write back dK.  dk accumulated against the pre-scaled K, so scale once.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])
    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed N_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX
    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    # other=inf makes exp2(qk - m) vanish for out-of-range rows.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]
    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )
    # Stage 2 - non-masked blocks
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2
    if stage2_num_steps > 0:
        # NOTE(review): the stage-2 start offset below can be negative when
        # start_n is not a multiple of BLOCK_N2 (i.e. BLOCK_N1 % BLOCK_N2 != 0)
        # — presumably the tuned configs guarantee divisibility; confirm.
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ, undoing the log2-space pre-scaling of K.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    # tl.store(dq_ptrs, dq)
    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton SDPA forward kernel.

    Args:
        query, key, value: tensors indexed as (batch, head, seqlen, head_dim)
            judging by the shape/stride usage below.
        attn_mask: optional mask broadcastable against
            (batch, head, q_seqlen, kv_seqlen).  A bool mask follows the
            PyTorch SDPA convention (True = attend) and is converted to an
            additive float mask; a float mask is added to the scores as-is.
        dropout_p: must be 0.0 (only supported value).
        is_causal: apply lower-triangular causal masking when True.
        scale: softmax scale; defaults to 1/sqrt(head_dim) when None.
        enable_gqa: must be True when query and key/value head counts differ.

    Returns:
        (o, M): attention output and the per-row log2-sum-exp statistics of
        shape (batch, q_head, q_seqlen), needed by the backward pass.
    """
    logger.debug("GEMS_CAMBRICON SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"
    o = torch.empty_like(query, dtype=value.dtype)
    # STAGE=3 selects the kernel's two-phase causal path; STAGE=1 is full scan.
    stage = 3 if is_causal else 1
    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale
    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )
    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )
    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # PyTorch SDPA convention: True means "take part in attention".
            # Build the additive mask accordingly: 0 where True, a large
            # negative value where False.  (Previously the True positions
            # were penalized, inverting the mask.)
            attn_mask = (~attn_mask).to(query.dtype) * -1.0e6
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        # dummy strides; the kernel never dereferences the mask in this case
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1
    # per-row log2-sum-exp, consumed by the backward pass
    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )
    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            query.shape[0],
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            query.shape[2],  #
            key.shape[2],  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton SDPA backward kernels and return (dq, dk, dv).

    Args:
        do: gradient w.r.t. the forward output ``o``; must be contiguous.
        query, key, value, o: the contiguous forward tensors.
        M: per-row log2-sum-exp produced by the forward pass.
        attn_mask: unused here (the backward kernel has no mask support).
        dropout_p: must be 0.0 (only supported value).
        is_causal, scale, enable_gqa: must match the forward call.

    Returns:
        (dq, dk, dv) with dk/dv reduced over the GQA group so they match
        key/value head counts.
    """
    logger.debug("GEMS_CAMBRICON SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"
    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale
    # the kernels assume plain contiguous layouts with shared strides
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()
    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD
    # NUM_WARPS, NUM_STAGES = 4, 1
    # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
    BLK_SLICE_FACTOR = 2
    # Pre-scale K by sm_scale / ln(2) so the kernel can use exp2 directly;
    # dq is rescaled by ln(2) and dk by sm_scale inside the kernel.
    RCP_LN2 = 1.0 / math.log(2)
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)
    delta = torch.empty_like(M)
    # NOTE that dk & dv always have the same number of heads as q;
    # they are reduced over the GQA group below.
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )
    # delta[b, h, m] = sum_d o * do, needed by both dkdv and dq kernels
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        Q_HEAD,
        Q_CTX,  #
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,  #
    )
    # size the grid for the largest BLOCK_N1 the autotuner may pick
    max_block_n1 = (
        max([cfg.kwargs["BLOCK_N1"] for cfg in config_backward])
        if config_backward
        else 128
    )
    grid = (triton.cdiv(Q_CTX, max_block_n1), 1, BATCH * Q_HEAD)
    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,  #
        M,
        delta,  #
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),  #
        key.stride(0),
        key.stride(1),  #
        Q_HEAD,
        Q_CTX,  #
        KV_CTX,  #
        KV_HEAD,  #
        GROUP_HEAD=group_head,  #
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        BLOCK_DMODEL=BLOCK_DMODEL,  #
    )
    # GQA: fold the per-q-head dk/dv back down to KV_HEAD heads
    if group_head > 1:
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)
    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge between the Triton SDPA forward and backward launchers."""

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Resolve the softmax scale once so forward and backward agree.
        if scale is None:
            sm_scale = 1.0 / (key.shape[-1] ** 0.5)
        else:
            sm_scale = scale
        out, logsumexp = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )
        # Stash everything the backward launcher needs.
        ctx.save_for_backward(query, key, value, out, logsumexp)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return out

    @staticmethod
    def backward(ctx, do):
        query, key, value, out, logsumexp = ctx.saved_tensors
        dq, dk, dv = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            out,
            logsumexp,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One gradient slot per forward input; the non-tensor args get None.
        return dq, dk, dv, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Public SDPA entry point; dispatches through the autograd Function."""
    args = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*args)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward entry point for the Cambricon backend.

    Pads the head dim up to the nearest supported size when necessary,
    dispatches to ``mha_fwd`` (or ``mha_varlan_fwd`` for varlen inputs),
    and slices the padding off the output.

    Returns:
        (out, lse, philox_seed, philox_offset, p) as produced by the
        underlying mha kernel, with ``out`` cropped to the original head dim.
    """
    logger.debug("GEMS_CAMBRICON FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    # mha kernels only accept these head dims; pad up to the next one.
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        padded_head_dim = None
        for d in supported_head_dims:
            if d >= HEAD_DIM_K:
                padded_head_dim = d
                break
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim
    # NOTE: `scale or ...` would wrongly replace an explicit scale of 0.0;
    # only fall back to the default when scale is actually None.  The default
    # uses the ORIGINAL head dim, not the padded one.
    softmax_scale = scale if scale is not None else 1.0 / (original_head_dim**0.5)
    if window_size_left is not None:
        non_null_window_left = window_size_left
    else:
        non_null_window_left = -1
    if window_size_right is not None:
        non_null_window_right = window_size_right
    else:
        non_null_window_right = -1
    out = torch.empty_like(query)
    # The varlen branch is currently unreachable (see the assert above) but
    # kept for when varlen support lands.
    if cumulative_sequence_length_q is not None:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            # pass the resolved scale, not the possibly-None `scale`,
            # matching the fixed-length branch below
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )
    # crop any head-dim padding back off
    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
1174# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` untouched when it is None or already unit-strided in the
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
1179def flash_attn_varlen_func(
1180 q,
1181 k,
1182 v,
1183 max_seqlen_q,
1184 cu_seqlens_q,
1185 max_seqlen_k,
1186 cu_seqlens_k=None, # only used for non-paged prefill
1187 seqused_k=None,
1188 q_v=None,
1189 dropout_p=0.0,
1190 softmax_scale=None,
1191 causal=False,
1192 window_size=None,
1193 softcap=0.0, # 0.0 means deactivated
1194 alibi_slopes=None,
1195 deterministic=False,
1196 return_attn_probs=False,
1197 block_table=None,
1198 return_softmax_lse=False,
1199 out=None,
1200 # Dummy FA3 arguments
1201 scheduler_metadata=None,
1202 q_descale=None,
1203 k_descale=None,
1204 v_descale=None,
1205 s_aux=None,
1206 num_splits: int = 0,
1207 cp_world_size: int = 1,
1208 cp_rank: int = 0,
1209 cp_tot_seqused_k=None,
1210 fa_version: int = 2,
1211):
1212 """dropout_p should be set to 0.0 during evaluation
1213 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1214 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1215 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1216 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1218 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1219 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1220 1 1 1 1 0
1221 1 1 1 1 1
1222 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1223 0 0
1224 0 0
1225 0 0
1226 1 0
1227 1 1
1228 If the row of the mask is all zero, the output will be zero.
1230 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1231 will only attend to keys between
1232 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1234 Arguments:
1235 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1236 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1237 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1238 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1239 of the sequences in the batch, used to index into q.
1240 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1241 of the sequences in the batch, used to index into kv.
1242 max_seqlen_q: int. Maximum query sequence length in the batch.
1243 max_seqlen_k: int. Maximum key sequence length in the batch.
1244 dropout_p: float. Dropout probability.
1245 softmax_scale: float. The scaling of QK^T before applying softmax.
1246 Default to 1 / sqrt(headdim).
1247 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1248 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1249 softcap: float. Anything > 0 activates softcapping attention.
1250 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1251 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1252 is added to the attention score of query i and key j.
1253 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1254 which is slightly slower and uses more memory. The forward pass is always deterministic.
1255 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1256 testing only. The returned probabilities are not guaranteed to be correct
1257 (they might not have the right scaling).
1258 Return:
1259 out: (total, nheads, headdim).
1260 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1261 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1262 normalization factor).
1263 """
1264 if fa_version != 2:
1265 raise RuntimeError("Only FA2 is implemented.")
1266 if num_splits > 0:
1267 raise RuntimeError("num_splits > 0 is not implemented in GEMS_CAMBRICON.")
1268 if use_c_extension:
1269 logger.debug("GEMS_CAMBRICON FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1270 with torch_device_fn.device(q.device):
1271 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1272 q,
1273 k,
1274 v,
1275 max_seqlen_q,
1276 cu_seqlens_q,
1277 max_seqlen_k,
1278 cu_seqlens_k,
1279 seqused_k,
1280 q_v,
1281 dropout_p,
1282 softmax_scale,
1283 causal,
1284 window_size,
1285 softcap,
1286 alibi_slopes,
1287 deterministic,
1288 return_attn_probs,
1289 block_table,
1290 return_softmax_lse,
1291 out,
1292 scheduler_metadata,
1293 q_descale,
1294 k_descale,
1295 v_descale,
1296 s_aux,
1297 num_splits,
1298 cp_world_size,
1299 cp_rank,
1300 cp_tot_seqused_k,
1301 fa_version,
1302 )
1303 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1304 else:
1305 logger.debug("GEMS_CAMBRICON FLASH_ATTN_VARLEN_FUNC")
1306 assert (
1307 cu_seqlens_k is not None or seqused_k is not None
1308 ), "cu_seqlens_k or seqused_k must be provided"
1309 assert (
1310 cu_seqlens_k is None or seqused_k is None
1311 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1312 assert (
1313 block_table is None or seqused_k is not None
1314 ), "seqused_k must be provided if block_table is provided"
1315 if softmax_scale is None:
1316 softmax_scale = q.shape[-1] ** (-0.5)
1317 # custom op does not support non-tuple input
1318 if window_size is None:
1319 real_window_size = (-1, -1)
1320 else:
1321 assert len(window_size) == 2
1322 real_window_size = (window_size[0], window_size[1])
1323 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1324 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1325 max_seqlen_q = (
1326 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1327 )
1328 max_seqlen_k = (
1329 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1330 )
1331 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
1332 q,
1333 k,
1334 v,
1335 out,
1336 cu_seqlens_q,
1337 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1338 # still wants it so we pass all zeros
1339 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1340 seqused_k,
1341 None,
1342 block_table,
1343 alibi_slopes,
1344 max_seqlen_q,
1345 max_seqlen_k,
1346 dropout_p,
1347 softmax_scale,
1348 False,
1349 causal,
1350 real_window_size[0],
1351 real_window_size[1],
1352 softcap,
1353 return_softmax_lse and dropout_p > 0,
1354 None,
1355 )
1357 return (out, softmax_lse) if return_softmax_lse else out