Coverage for src/flag_gems/ops/attention.py: 29%
399 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd
13from flag_gems.ops.flash_kernel import keep
14from flag_gems.runtime import torch_device_fn
15from flag_gems.utils import libentry, libtuner
# Module-level logger named after this module; the wrappers below emit their
# debug traces through it.
logger = logging.getLogger(__name__)
20# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
# Inner loop of the forward attention kernel (online softmax over KV blocks).
#
# For one BLOCK_M tile of queries, walks the key/value sequence in BLOCK_N
# steps and updates the running output accumulator `acc`, the running row sum
# `l_i` and the running row maximum `m_i`. Scores are kept in the log2 domain
# (multiplied by log2(e)) because the probabilities are computed with `exp2`.
#
# STAGE selects the key range that is visited:
#   1         -> [0, start_m * BLOCK_M)
#   2         -> [start_m * BLOCK_M, (start_m + 1) * BLOCK_M), with the
#                `offs_m >= key index` (causal) mask applied
#   otherwise -> [0, KV_CTX)
#
# Returns the updated (acc, l_i, m_i).
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    # advance the block pointers to the first key/value row of this stage
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # in case KV_CTX is not divisible by BLOCK_N: padded keys get -inf
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
        # qk = qk.to(tl.float32)

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk *= qk_scale * LOG2E
            if HAS_ATTN_MASK:
                # The additive bias has to live in the same log2 domain as the
                # scaled scores (exp2 is applied below), exactly as in the
                # STAGE == 2 branch; otherwise a finite bias is effectively
                # multiplied by ln(2).
                qk = qk + attn_mask.to(tl.float32) * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i
# NOTE: _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so on top of the
# tuned configs we add a family of small-BLOCK_N configs that remain valid
# for small head dims.
configs = runtime.get_tuned_config("attention")
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m in [64, 128]
    for block_n in [16, 32]
    for stages in [2, 3, 4]
    for warps in [4, 8]
]
configs += SMALL_HEAD_DIM_CONFIGS
# Forward attention kernel.
#
# Each program handles one BLOCK_M tile of queries (program_id(0)) for one
# (batch, query head) pair (program_id(1)). It writes the attention output to
# `Out` and the per-row log2-sum-exp statistic to `M`.
#
# STAGE is 3 for causal attention and 1 otherwise (see the stage comments in
# the body); GROUP_HEAD is the number of query heads that share one KV head.
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    kv_head_id = head_id // GROUP_HEAD

    # Base offsets are computed in int64 so large tensors do not overflow.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    k_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )
    # V gets its own base offset from V's batch/head strides, so K and V may
    # have different memory layouts.
    v_offset = (
        batch_id.to(tl.int64) * stride_v_batch + kv_head_id.to(tl.int64) * stride_v_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed as a (HEAD_DIM, BLOCK_N) tile so that dot(query, key)
    # directly yields the (BLOCK_M, BLOCK_N) score tile.
    K_block_ptr = (
        K
        + k_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + v_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    # qk_scale *= 1.44269504  # 1/log(2)
    # load query: it is loaded once and reused by every inner-loop iteration
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: fold the row sum into the running max (log2-sum-exp) and
    # normalize the accumulator
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
# Backward pre-pass: for every query row computes
#     Delta[row] = sum over the head dim of O[row, :] * dO[row, :]
# One program covers BLOCK_M rows (program_id(0)) of one batch*head slice
# (program_id(1)). The addressing uses D_HEAD as the row stride and
# D_HEAD * Q_CTX as the slice stride for both O and DO. Z and H are not
# referenced in the body.
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_in_range = rows < Q_CTX

    bh_idx = tl.program_id(1)
    cols = tl.arange(0, D_HEAD)

    # Rows beyond Q_CTX are read as zeros so they contribute nothing.
    o_ptrs = O + bh_idx * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :]
    do_ptrs = DO + bh_idx * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :]
    o = tl.load(o_ptrs, mask=rows_in_range[:, None], other=0.0)
    do = tl.load(do_ptrs, mask=rows_in_range[:, None], other=0.0).to(tl.float32)

    # Row-wise dot product of O and dO.
    delta = tl.sum(o * do, axis=1)

    delta_ptrs = Delta + bh_idx * Q_CTX + rows
    tl.store(delta_ptrs, delta, mask=rows_in_range)
357# The main inner-loop logic for computing dK and dV.
# Accumulates dK and dV for one (BLOCK_N1, BLOCK_DMODEL) key/value tile.
#
# Starting at query row `start_m`, visits `num_steps` consecutive query tiles
# of BLOCK_M1 rows and adds their contribution to `dk` and `dv`. `key` and
# `value` are the already-loaded tiles for key rows
# [start_n, start_n + BLOCK_N1). When MASK is set, only pairs with
# query index >= key index contribute. `sm_scale` and `H` are not referenced
# in the body.
#
# Returns the updated (dk, dv).
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range rows get +inf so that exp2(qk - m) evaluates to 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)
        # pT = tl.math.exp2(qkT - m[None, :])

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # Rows past Q_CTX must not be read: an unmasked load is out of bounds,
        # and any NaN/Inf garbage would survive the multiplication by the
        # zeroed pT entries and corrupt dv.
        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv
454# the main inner-loop logic for computing dQ
# Accumulates dQ for one (BLOCK_M2, BLOCK_DMODEL) query tile.
#
# `query`, `do` and `m` are the already-loaded tiles for query rows
# [start_m, start_m + BLOCK_M2). Starting at key row `start_n`, the loop
# visits `num_steps` consecutive key/value tiles of BLOCK_N2 rows and adds
# their contribution to `dq`. When MASK is set, only pairs with
# query index >= key index contribute. `H` is not referenced in the body.
#
# Returns the updated dq.
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    q_rows = start_m + tl.arange(0, BLOCK_M2)
    q_rows_ok = q_rows < Q_CTX

    dim_idx = tl.arange(0, BLOCK_DMODEL)
    # Per-row delta term; rows past Q_CTX read as 0.
    delta_rows = tl.load(D + q_rows, mask=q_rows_ok, other=0.0)
    # The tiling only works when BLOCK_M2 is a multiple of BLOCK_N2.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    kv_start = start_n
    for step in range(num_steps):
        kv_rows = kv_start + tl.arange(0, BLOCK_N2)
        kv_rows_ok = kv_rows < KV_CTX

        # Transposed (BLOCK_DMODEL, BLOCK_N2) views of the K and V tiles.
        k_t_ptrs = K + kv_rows[None, :] * stride_tok + dim_idx[:, None] * stride_d
        v_t_ptrs = V + kv_rows[None, :] * stride_tok + dim_idx[:, None] * stride_d
        k_t = tl.load(k_t_ptrs, mask=kv_rows_ok[None, :], other=0.0)
        v_t = tl.load(v_t_ptrs, mask=kv_rows_ok[None, :], other=0.0)

        # Recompute the probabilities from the saved row statistic `m`.
        scores = tl.dot(query, k_t)
        probs = tl.math.exp2(scores - m)

        valid = q_rows_ok[:, None] & kv_rows_ok[None, :]
        if MASK:
            # Autoregressive masking.
            valid &= q_rows[:, None] >= kv_rows[None, :]
        probs = tl.where(valid, probs, 0.0)

        # dP and dS.
        d_probs = tl.dot(do, v_t).to(tl.float32)
        d_scores = probs * (d_probs - delta_rows[:, None])
        d_scores = tl.where(valid, d_scores, 0.0).to(k_t.dtype)

        # dQ contribution of this key tile.
        dq += tl.dot(d_scores, tl.trans(k_t))

        # Move on to the next key/value tile.
        kv_start += BLOCK_N2
    return dq
# Tuning configs for the backward kernel, looked up under the "attention_bwd"
# key. Besides feeding the tuner of _attn_bwd, the list is read by
# scaled_dot_product_attention_backward to size its launch grid.
config_backward = runtime.get_tuned_config("attention_bwd")
# Backward attention kernel: computes dQ, dK and dV.
#
# program_id(2) selects the (batch, query head) pair and program_id(0) selects
# the tile: the same `pid` is used both for the key/value tile
# [pid * BLOCK_N1, ...) whose dK/dV are produced and for the query tile
# [pid * BLOCK_M2, ...) whose dQ is produced.
#
# NOTE(review): DK and DV are offset with the same `adj` as Q/DO, i.e. they
# are assumed to share Q's batch/head strides - confirm this holds when
# Q_CTX != KV_CTX.
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    # Widen the indices to int64 *before* multiplying (as _attn_fwd does) so
    # the products cannot overflow 32 bits on large tensors.
    off_chz = bhid.to(tl.int64) * Q_CTX
    adj = stride_h * q_head_id.to(tl.int64) + stride_z * batch_id.to(tl.int64)
    kv_adj = (
        kv_stride_h * kv_head_id.to(tl.int64) + kv_stride_z * batch_id.to(tl.int64)
    )

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V once: the tiles are reused by every inner-loop iteration.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    # Masked (diagonal) blocks: walk the BLOCK_N1 query rows that overlap this
    # key tile in slices of MASK_BLOCK_M1.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed KV_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range rows get +inf so that exp2(qk - m) evaluates to 0.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks: the keys in [0, start_n).
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # De-scale by ln(2): the scores were formed in the log2 domain.
    dq *= LN2
    # tl.store(dq_ptrs, dq)

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton forward attention kernel.

    The first three dimensions of ``query`` are used as (batch, heads,
    sequence) and the last dimension of query/key/value as the head size.

    Args:
        query, key, value: input tensors; their last dimensions must be equal
            and one of 16, 32, 64, 128 or 256.
        attn_mask: optional mask indexed with four strides. A boolean mask is
            converted to ``query.dtype`` and multiplied by ``-1.0e6``; any
            other dtype is handed to the kernel unchanged.
        dropout_p: must be ``0.0``.
        is_causal: selects the causal kernel stage (3) instead of stage 1.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: must be true when query and key head counts differ.

    Returns:
        ``(o, M)``: the output tensor (shaped like ``query``, dtype of
        ``value``) and the float32 per-row statistic written by the kernel,
        shaped ``query.shape[:3]``.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")

    # shape constraints
    head_dim_q = query.shape[-1]
    head_dim_k = key.shape[-1]
    head_dim_v = value.shape[-1]
    assert head_dim_q == head_dim_k and head_dim_k == head_dim_v
    assert head_dim_k in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    stage = 3 if is_causal else 1
    sm_scale = 1.0 / (head_dim_k**0.5) if scale is None else scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    def grid(meta):
        # One program per BLOCK_M query tile and per (batch, head) pair.
        return (
            triton.cdiv(query.shape[2], meta["BLOCK_M"]),
            query.shape[0] * query.shape[1],
            1,
        )

    has_attn_mask = attn_mask is not None
    if has_attn_mask:
        if attn_mask.dtype == torch.bool:
            # NOTE(review): this penalises the True entries of a boolean mask;
            # confirm that matches the convention callers expect.
            attn_mask = attn_mask.to(query.dtype) * -1.0e6
        mask_strides = [attn_mask.stride(dim) for dim in range(4)]
    else:
        # Placeholder strides; the kernel never reads them without a mask.
        mask_strides = [1, 1, 1, 1]

    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    q_strides = [query.stride(dim) for dim in range(4)]
    k_strides = [key.stride(dim) for dim in range(4)]
    v_strides = [value.stride(dim) for dim in range(4)]
    o_strides = [o.stride(dim) for dim in range(4)]

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *mask_strides,
            *o_strides,
            query.shape[0],
            q_head_num,
            kv_head_num,
            q_head_num // kv_head_num,  # group_head
            query.shape[2],
            key.shape[2],
            head_dim_k,
            STAGE=stage,
            HAS_ATTN_MASK=has_attn_mask,
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton backward attention kernels.

    Args:
        do: gradient of the attention output; contiguous, same strides as
            ``query`` and ``o``.
        query, key, value, o: forward inputs/output; all must be contiguous,
            and ``key``/``value`` must share strides.
        M: per-row statistic returned by the forward pass.
        attn_mask, is_causal, enable_gqa: accepted for signature parity; they
            are not referenced in this function and are not forwarded to the
            kernels.
        dropout_p: must be ``0.0``.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.

    Returns:
        ``(dq, dk, dv)``. ``dk``/``dv`` are first produced with one slice per
        query head and, when several query heads share a KV head, summed over
        each group.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # K is pre-scaled by sm_scale / ln(2) so the kernel can form its scores
    # directly in the log2 domain; the kernel de-scales dq by ln(2) at the end.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # NOTE(review): the grid is sized from the largest BLOCK_N1 among the
    # tuning configs; confirm that every config the tuner may select covers
    # the whole sequence with this grid.
    max_block_n1 = (
        max(cfg.kwargs["BLOCK_N1"] for cfg in config_backward)
        if config_backward
        else 128
    )
    grid = (triton.cdiv(Q_CTX, max_block_n1), 1, BATCH * Q_HEAD)

    # Launch under the inputs' device, as the forward wrapper does, so the
    # kernels do not run on whatever device happens to be current.
    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Fold the per-query-head gradients back onto the shared KV heads.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd function pairing the Triton forward and backward attention passes."""

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        """Run the forward pass and stash what the backward pass needs.

        The softmax scale defaults to ``1 / sqrt(key.shape[-1])`` when
        ``scale`` is ``None``. Returns the attention output only.
        """
        if scale is None:
            sm_scale = 1.0 / (key.shape[-1] ** 0.5)
        else:
            sm_scale = scale

        out, row_stats = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, out, row_stats)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return out

    @staticmethod
    def backward(ctx, do):
        """Return gradients for query, key and value; other inputs get ``None``.

        The backward call always uses ``attn_mask=None`` and
        ``dropout_p=0.0``, regardless of what the forward pass received.
        """
        query, key, value, out, row_stats = ctx.saved_tensors
        grad_q, grad_k, grad_v = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            out,
            row_stats,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One slot per forward argument after ctx: the five non-tensor /
        # mask arguments receive no gradient.
        return grad_q, grad_k, grad_v, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable attention entry point.

    Forwards every argument, positionally and in declaration order, to
    ``ScaleDotProductAttention.apply`` and returns its result.
    """
    forwarded = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forwarded)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward pass on top of ``mha_fwd``.

    Head sizes that are not one of 16, 32, 64, 96, 128, 192 or 256 are
    zero-padded up to the next supported size before the call, and the output
    is sliced back to the original head size afterwards.

    Args:
        query, key, value: inputs whose last dimensions must be equal.
        cumulative_sequence_length_q, cumulative_sequence_length_k: must both
            be ``None`` (variable-length input is rejected by an assertion).
        scale: softmax scale; when ``None`` it defaults to
            ``1 / sqrt(head_dim)`` using the *unpadded* head size. An explicit
            ``0.0`` is honoured.
        window_size_left, window_size_right: ``None`` is translated to ``-1``.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        underlying kernel wrapper.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported size that can hold the requested head dim.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Test against None rather than truthiness so an explicit scale of 0.0 is
    # not silently replaced by the default.
    if scale is None:
        softmax_scale = 1.0 / (original_head_dim**0.5)
    else:
        softmax_scale = scale
    if window_size_left is not None:
        non_null_window_left = window_size_left
    else:
        non_null_window_left = -1
    if window_size_right is not None:
        non_null_window_right = window_size_right
    else:
        non_null_window_right = -1

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # NOTE(review): only reachable when assertions are disabled; it
        # forwards the raw `scale` (possibly None) rather than
        # `softmax_scale` - confirm what mha_varlan_fwd expects.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        # Drop the zero padding that was added to the head dimension.
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
1177# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` made contiguous only if its last-dimension stride is not 1.

    ``None`` and tensors whose innermost stride is already 1 are returned
    unchanged (same object).
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
1182def flash_attn_varlen_func(
1183 q,
1184 k,
1185 v,
1186 max_seqlen_q,
1187 cu_seqlens_q,
1188 max_seqlen_k,
1189 cu_seqlens_k=None, # only used for non-paged prefill
1190 seqused_k=None,
1191 q_v=None,
1192 dropout_p=0.0,
1193 softmax_scale=None,
1194 causal=False,
1195 window_size=None,
1196 softcap=0.0, # 0.0 means deactivated
1197 alibi_slopes=None,
1198 deterministic=False,
1199 return_attn_probs=False,
1200 block_table=None,
1201 return_softmax_lse=False,
1202 out=None,
1203 # Dummy FA3 arguments
1204 scheduler_metadata=None,
1205 q_descale=None,
1206 k_descale=None,
1207 v_descale=None,
1208 s_aux=None,
1209 num_splits: int = 0,
1210 cp_world_size: int = 1,
1211 cp_rank: int = 0,
1212 cp_tot_seqused_k=None,
1213 fa_version: int = 2,
1214):
1215 """dropout_p should be set to 0.0 during evaluation
1216 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1217 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1218 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1219 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1221 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1222 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1223 1 1 1 1 0
1224 1 1 1 1 1
1225 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1226 0 0
1227 0 0
1228 0 0
1229 1 0
1230 1 1
1231 If the row of the mask is all zero, the output will be zero.
1233 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1234 will only attend to keys between
1235 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1237 Arguments:
1238 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1239 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1240 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1241 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1242 of the sequences in the batch, used to index into q.
1243 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1244 of the sequences in the batch, used to index into kv.
1245 max_seqlen_q: int. Maximum query sequence length in the batch.
1246 max_seqlen_k: int. Maximum key sequence length in the batch.
1247 dropout_p: float. Dropout probability.
1248 softmax_scale: float. The scaling of QK^T before applying softmax.
1249 Default to 1 / sqrt(headdim).
1250 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1251 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1252 softcap: float. Anything > 0 activates softcapping attention.
1253 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1254 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1255 is added to the attention score of query i and key j.
1256 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1257 which is slightly slower and uses more memory. The forward pass is always deterministic.
1258 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1259 testing only. The returned probabilities are not guaranteed to be correct
1260 (they might not have the right scaling).
1261 Return:
1262 out: (total, nheads, headdim).
1263 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1264 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1265 normalization factor).
1266 """
1267 if fa_version != 2:
1268 raise RuntimeError("Only FA2 is implemented.")
1269 if num_splits > 0:
1270 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1271 if use_c_extension:
1272 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1273 with torch_device_fn.device(q.device):
1274 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1275 q,
1276 k,
1277 v,
1278 max_seqlen_q,
1279 cu_seqlens_q,
1280 max_seqlen_k,
1281 cu_seqlens_k,
1282 seqused_k,
1283 q_v,
1284 dropout_p,
1285 softmax_scale,
1286 causal,
1287 window_size,
1288 softcap,
1289 alibi_slopes,
1290 deterministic,
1291 return_attn_probs,
1292 block_table,
1293 return_softmax_lse,
1294 out,
1295 scheduler_metadata,
1296 q_descale,
1297 k_descale,
1298 v_descale,
1299 s_aux,
1300 num_splits,
1301 cp_world_size,
1302 cp_rank,
1303 cp_tot_seqused_k,
1304 fa_version,
1305 )
1306 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1307 else:
1308 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
1309 assert (
1310 cu_seqlens_k is not None or seqused_k is not None
1311 ), "cu_seqlens_k or seqused_k must be provided"
1312 assert (
1313 cu_seqlens_k is None or seqused_k is None
1314 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1315 assert (
1316 block_table is None or seqused_k is not None
1317 ), "seqused_k must be provided if block_table is provided"
1318 if softmax_scale is None:
1319 softmax_scale = q.shape[-1] ** (-0.5)
1320 # custom op does not support non-tuple input
1321 if window_size is None:
1322 real_window_size = (-1, -1)
1323 else:
1324 assert len(window_size) == 2
1325 real_window_size = (window_size[0], window_size[1])
1326 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1327 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1328 max_seqlen_q = (
1329 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1330 )
1331 max_seqlen_k = (
1332 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1333 )
1334 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
1335 q,
1336 k,
1337 v,
1338 out,
1339 cu_seqlens_q,
1340 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1341 # still wants it so we pass an uninitialized placeholder (torch.empty_like of cu_seqlens_q)
1342 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1343 seqused_k,
1344 None,
1345 block_table,
1346 alibi_slopes,
1347 max_seqlen_q,
1348 max_seqlen_k,
1349 dropout_p,
1350 softmax_scale,
1351 False,
1352 causal,
1353 real_window_size[0],
1354 real_window_size[1],
1355 softcap,
1356 return_softmax_lse and dropout_p > 0,
1357 None,
1358 )
1360 return (out, softmax_lse) if return_softmax_lse else out