Coverage for src/flag_gems/runtime/backend/_hygon/ops/attention.py: 0%
399 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
15from .flash_api import mha_fwd, mha_varlan_fwd
16from .flash_kernel import keep
18logger = logging.getLogger(__name__)
21# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
22@triton.jit
23def _attn_fwd_inner(
24 acc,
25 l_i,
26 m_i,
27 query, #
28 K_block_ptr,
29 V_block_ptr, #
30 mask_block_ptr, #
31 stride_k_seqlen,
32 stride_v_seqlen,
33 stride_attn_mask_kv_seqlen, #
34 start_m,
35 qk_scale, #
36 q_load_mask,
37 BLOCK_M: tl.constexpr,
38 HEAD_DIM: tl.constexpr,
39 BLOCK_N: tl.constexpr, #
40 STAGE: tl.constexpr,
41 offs_m: tl.constexpr,
42 offs_n: tl.constexpr, #
43 KV_CTX: tl.constexpr,
44 fp8_v: tl.constexpr,
45 HAS_ATTN_MASK: tl.constexpr,
46 PRE_LOAD_V: tl.constexpr,
47):
48 # range of values handled by this stage
49 if STAGE == 1:
50 lo, hi = 0, start_m * BLOCK_M
51 elif STAGE == 2:
52 lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
53 # causal = False
54 else:
55 lo, hi = 0, KV_CTX
57 K_block_ptr += lo * stride_k_seqlen
58 V_block_ptr += lo * stride_v_seqlen
59 if HAS_ATTN_MASK:
60 mask_block_ptr += lo * stride_attn_mask_kv_seqlen
62 LOG2E = 1.44269504 # log2(e) constant
64 # loop over key, value and update accumulator
65 for start_n in range(lo, hi, BLOCK_N):
66 kv_load_mask = (start_n + offs_n) < KV_CTX
67 # -- compute qk ----
68 key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
69 if PRE_LOAD_V:
70 value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
72 qk = tl.dot(query, key, allow_tf32=False)
73 # incase not divisible.
74 qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
75 # qk = qk.to(tl.float32)
77 if HAS_ATTN_MASK:
78 attn_mask = tl.load(
79 mask_block_ptr,
80 mask=q_load_mask[:, None] & kv_load_mask[None, :],
81 other=0.0,
82 )
84 if STAGE == 2:
85 mask = offs_m[:, None] >= (start_n + offs_n[None, :])
87 if HAS_ATTN_MASK:
88 qk = qk * qk_scale + attn_mask
89 qk *= LOG2E
90 qk = qk + tl.where(mask, 0, -1.0e6)
91 else:
92 qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)
94 m_ij = tl.maximum(m_i, tl.max(qk, 1))
95 qk -= m_ij[:, None]
96 else:
97 qk *= qk_scale * LOG2E
98 if HAS_ATTN_MASK:
99 qk = qk + attn_mask
100 m_ij = tl.maximum(m_i, tl.max(qk, 1))
101 qk = qk - m_ij[:, None]
103 p = tl.math.exp2(qk)
104 l_ij = tl.sum(p, 1)
105 # -- update m_i and l_i
106 alpha = tl.math.exp2(m_i - m_ij)
107 l_i = l_i * alpha + l_ij
108 # -- update output accumulator --
109 acc = acc * alpha[:, None]
110 # update acc
111 if not PRE_LOAD_V:
112 value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
113 if fp8_v:
114 p = p.to(tl.float8e5)
115 else:
116 p = p.to(query.dtype)
117 p = p.to(value.dtype)
118 acc = tl.dot(p, value, acc, allow_tf32=False)
119 # update m_i and l_i
120 m_i = m_ij
122 K_block_ptr += BLOCK_N * stride_k_seqlen
123 V_block_ptr += BLOCK_N * stride_v_seqlen
125 if HAS_ATTN_MASK:
126 mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen
128 return acc, l_i, m_i
131# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
132# we need to generate more configs.
133configs = runtime.get_tuned_config("attention")
134SMALL_HEAD_DIM_CONFIGS = [
135 triton.Config(
136 {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
137 )
138 for BM in [64, 128]
139 for BN in [16, 32]
140 for s in [2, 3, 4]
141 for w in [4, 8]
142]
143configs += SMALL_HEAD_DIM_CONFIGS
146@libentry()
147@libtuner(
148 configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
149 key=["KV_CTX", "HEAD_DIM"],
150)
151@triton.jit
152def _attn_fwd(
153 Q,
154 K,
155 V,
156 attn_mask,
157 sm_scale,
158 M,
159 Out, #
160 stride_q_batch,
161 stride_q_head,
162 stride_q_seqlen,
163 stride_q_headsize,
164 stride_k_batch,
165 stride_k_head,
166 stride_k_seqlen,
167 stride_k_headsize,
168 stride_v_batch,
169 stride_v_head,
170 stride_v_seqlen,
171 stride_v_headsize,
172 stride_attn_mask_batch,
173 stride_attn_mask_head,
174 stride_attn_mask_q_seqlen,
175 stride_attn_mask_kv_seqlen,
176 stride_o_batch,
177 stride_o_head,
178 stride_o_seqlen,
179 stride_o_headsize,
180 Z,
181 q_head_num,
182 kv_head_num,
183 GROUP_HEAD: tl.constexpr,
184 Q_CTX,
185 KV_CTX,
186 HEAD_DIM: tl.constexpr,
187 BLOCK_M: tl.constexpr,
188 BLOCK_N: tl.constexpr,
189 STAGE: tl.constexpr,
190 HAS_ATTN_MASK: tl.constexpr,
191 PRE_LOAD_V: tl.constexpr,
192):
193 tl.static_assert(BLOCK_N <= HEAD_DIM)
194 start_m = tl.program_id(0)
195 off_hz = tl.program_id(1)
196 batch_id = off_hz // q_head_num
197 head_id = off_hz % q_head_num
198 kv_head_id = head_id // GROUP_HEAD
200 q_offset = (
201 batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
202 )
203 o_offset = (
204 batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
205 )
206 kv_offset = (
207 batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
208 )
210 offs_headsize = tl.arange(0, HEAD_DIM)
212 # initialize offsets
213 offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
214 q_load_mask = offs_m < Q_CTX
215 offs_n = tl.arange(0, BLOCK_N)
217 Q_block_ptr = (
218 Q
219 + q_offset
220 + offs_m[:, None] * stride_q_seqlen
221 + offs_headsize[None, :] * stride_q_headsize
222 )
223 K_block_ptr = (
224 K
225 + kv_offset
226 + offs_n[None, :] * stride_k_seqlen
227 + offs_headsize[:, None] * stride_k_headsize
228 )
229 V_block_ptr = (
230 V
231 + kv_offset
232 + offs_n[:, None] * stride_v_seqlen
233 + offs_headsize[None, :] * stride_v_headsize
234 )
236 if HAS_ATTN_MASK:
237 attn_mask_offset = (
238 batch_id.to(tl.int64) * stride_attn_mask_batch
239 + head_id.to(tl.int64) * stride_attn_mask_head
240 )
241 mask_block_ptr = (
242 attn_mask
243 + attn_mask_offset
244 + offs_m[:, None] * stride_attn_mask_q_seqlen
245 + offs_n[None, :] * stride_attn_mask_kv_seqlen
246 )
247 else:
248 mask_block_ptr = None
250 O_block_ptr = (
251 Out
252 + o_offset
253 + offs_m[:, None] * stride_o_seqlen
254 + offs_headsize[None, :] * stride_o_headsize
255 )
257 # initialize pointer to m and l
258 m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
259 l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
260 acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
261 # load scales
262 qk_scale = sm_scale
263 # qk_scale *= 1.44269504 # 1/log(2)
264 # load query: it will stay in SRAM throughout
265 query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
266 # stage 1: off-band
267 # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
268 # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
269 if STAGE & 1:
270 acc, l_i, m_i = _attn_fwd_inner(
271 acc,
272 l_i,
273 m_i,
274 query,
275 K_block_ptr,
276 V_block_ptr,
277 mask_block_ptr,
278 stride_k_seqlen,
279 stride_v_seqlen,
280 stride_attn_mask_kv_seqlen,
281 start_m,
282 qk_scale,
283 q_load_mask,
284 BLOCK_M,
285 HEAD_DIM,
286 BLOCK_N,
287 4 - STAGE,
288 offs_m,
289 offs_n,
290 KV_CTX,
291 V.dtype.element_ty == tl.float8e5,
292 HAS_ATTN_MASK,
293 PRE_LOAD_V,
294 )
295 # stage 2: on-band
296 if STAGE & 2:
297 # barrier makes it easier for compielr to schedule the
298 # two loops independently
299 acc, l_i, m_i = _attn_fwd_inner(
300 acc,
301 l_i,
302 m_i,
303 query,
304 K_block_ptr,
305 V_block_ptr,
306 mask_block_ptr,
307 stride_k_seqlen,
308 stride_v_seqlen,
309 stride_attn_mask_kv_seqlen,
310 start_m,
311 qk_scale,
312 q_load_mask,
313 BLOCK_M,
314 HEAD_DIM,
315 BLOCK_N,
316 2,
317 offs_m,
318 offs_n,
319 KV_CTX,
320 V.dtype.element_ty == tl.float8e5,
321 HAS_ATTN_MASK,
322 PRE_LOAD_V,
323 )
324 # epilogue
325 m_i += tl.math.log2(l_i)
326 acc = acc / l_i[:, None]
327 m_ptrs = M + off_hz * Q_CTX + offs_m
328 tl.store(m_ptrs, m_i, mask=q_load_mask)
329 tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
332@triton.jit
333def _attn_bwd_preprocess(
334 O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
335):
336 off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
337 mask = off_m < Q_CTX
339 off_hz = tl.program_id(1)
340 off_n = tl.arange(0, D_HEAD)
341 # load
342 o = tl.load(
343 O + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
344 mask=mask[:, None],
345 other=0.0,
346 )
347 do = tl.load(
348 DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
349 mask=mask[:, None],
350 other=0.0,
351 ).to(tl.float32)
352 delta = tl.sum(o * do, axis=1)
353 # write-back
354 tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)
357# The main inner-loop logic for computing dK and dV.
358@triton.jit
359def _attn_bwd_dkdv(
360 dk,
361 dv, #
362 Q,
363 key,
364 value,
365 sm_scale, #
366 DO, #
367 M,
368 D, #
369 # shared by Q/K/V/DO.
370 stride_tok,
371 stride_d, #
372 H,
373 Q_CTX,
374 KV_CTX,
375 BLOCK_M1: tl.constexpr, #
376 BLOCK_N1: tl.constexpr, #
377 BLOCK_DMODEL: tl.constexpr, #
378 # Filled in by the wrapper.
379 start_n,
380 start_m,
381 num_steps, #
382 MASK: tl.constexpr,
383):
384 # BLOCK_M1: 32
385 # BLOCK_N1: 128
386 offs_n = start_n + tl.arange(0, BLOCK_N1)
387 offs_n_mask = offs_n < KV_CTX # (BLOCK_N1, )
389 offs_k = tl.arange(0, BLOCK_DMODEL) # (BLOCK_DMODEL, )
391 # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
392 tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
393 curr_m = start_m
394 step_m = BLOCK_M1
395 for blk_idx in range(num_steps):
396 offs_m = curr_m + tl.arange(0, BLOCK_M1) # (BLOCK_M1, )
397 offs_m_mask = offs_m < Q_CTX # (BLOCK_M1, )
399 qT_ptrs = (
400 Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
401 ) # (BLOCK_DMODEL, BLOCK_M1)
402 do_ptrs = (
403 DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
404 ) # (BLOCK_M1, BLOCK_DMODEL)
406 qT = tl.load(
407 qT_ptrs, mask=offs_m_mask[None, :], other=0.0
408 ) # (BLOCK_DMODEL, BLOCK_M1)
410 # Load m before computing qk to reduce pipeline stall.
411 m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf")) # (BLOCK_M1, )
413 # key: (BLOCK_N1, BLOCK_DMODEL)
414 qkT = tl.dot(key, qT) # (BLOCK_N1, BLOCK_M1)
415 m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1)) # (BLOCK_N1, BLOCK_M1)
416 m = tl.where(offs_n_mask[:, None], m, float("inf")) # (BLOCK_N1, BLOCK_M1)
417 pT = tl.math.exp2(qkT - m)
418 # pT = tl.math.exp2(qkT - m[None, :])
420 mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
421 :, None
422 ] # (BLOCK_N1, BLOCK_M1)
423 # Autoregressive masking.
424 if MASK:
425 mask &= offs_m[None, :] >= offs_n[:, None]
426 pT = tl.where(mask, pT, 0.0) # (BLOCK_N1, BLOCK_M1)
428 do = tl.load(do_ptrs)
429 # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0) # (BLOCK_M1, BLOCK_DMODEL)
431 # Compute dV.
432 dv += tl.dot(pT, do.to(tl.float32)) # (BLOCK_N1, BLOCK_DMODEL)
433 # D (= delta) is pre-divided by ds_scale.
434 Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0) # (BLOCK_M1, )
436 # Compute dP and dS.
437 dpT = tl.dot(value, tl.trans(do)).to(
438 tl.float32
439 ) # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
440 dsT = pT * (dpT - Di[None, :]) # (BLOCK_N1, BLOCK_M1)
441 dsT = dsT.to(qT.dtype)
442 qT = tl.where(offs_m_mask[None, :], qT, 0.0) # (BLOCK_DMODEL, BLOCK_M1)
443 dsT = tl.where(
444 offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
445 ) # (BLOCK_N1, BLOCK_M1)
446 dk += tl.dot(
447 dsT, tl.trans(qT)
448 ) # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
449 # Increment pointers.
450 curr_m += step_m
451 return dk, dv
454# the main inner-loop logic for computing dQ
455@triton.jit
456def _attn_bwd_dq(
457 dq,
458 query,
459 K,
460 V,
461 do,
462 m,
463 D,
464 # shared by Q/K/V/DO.
465 stride_tok,
466 stride_d,
467 H,
468 Q_CTX,
469 KV_CTX,
470 BLOCK_M2: tl.constexpr,
471 BLOCK_N2: tl.constexpr,
472 BLOCK_DMODEL: tl.constexpr,
473 # Filled in by the wrapper.
474 start_m,
475 start_n,
476 num_steps,
477 MASK: tl.constexpr,
478):
479 offs_m = start_m + tl.arange(0, BLOCK_M2)
480 offs_m_mask = offs_m < Q_CTX
482 offs_k = tl.arange(0, BLOCK_DMODEL)
483 # D (= delta) is pre-divided by ds_scale.
484 Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
485 # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
486 tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
487 curr_n = start_n
488 step_n = BLOCK_N2
489 for blk_idx in range(num_steps):
490 offs_n = curr_n + tl.arange(0, BLOCK_N2)
491 offs_n_mask = offs_n < KV_CTX
493 kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
494 vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
496 kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
497 vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
498 qk = tl.dot(query, kT)
499 p = tl.math.exp2(qk - m)
500 mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
501 # Autoregressive masking.
502 if MASK:
503 # mask = (offs_m[:, None] >= offs_n[None, :])
504 # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :]
505 mask &= offs_m[:, None] >= offs_n[None, :]
506 p = tl.where(mask, p, 0.0)
507 # Compute dP and dS.
508 dp = tl.dot(do, vT).to(tl.float32)
509 ds = p * (dp - Di[:, None])
510 ds = tl.where(mask, ds, 0.0).to(kT.dtype)
511 # Compute dQ.
512 # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
513 dq += tl.dot(ds, tl.trans(kT))
514 # Increment pointers.
515 curr_n += step_n
516 return dq
519config_backward = runtime.get_tuned_config("attention_bwd")
522@libentry()
523@libtuner(
524 configs=config_backward,
525 key=["KV_CTX", "BLOCK_DMODEL"],
526)
527@triton.jit
528def _attn_bwd(
529 Q,
530 K,
531 V,
532 sm_scale, #
533 DO, #
534 DQ,
535 DK,
536 DV, #
537 M,
538 D,
539 # shared by Q/K/V/DO.
540 stride_z,
541 stride_h,
542 stride_tok,
543 stride_d, #
544 kv_stride_z,
545 kv_stride_h, #
546 H, # query head num
547 Q_CTX, #
548 KV_CTX, #
549 kv_head_num, #
550 GROUP_HEAD: tl.constexpr, #
551 BLOCK_M1: tl.constexpr, #
552 BLOCK_N1: tl.constexpr, #
553 BLOCK_M2: tl.constexpr, #
554 BLOCK_N2: tl.constexpr, #
555 BLK_SLICE_FACTOR: tl.constexpr, #
556 BLOCK_DMODEL: tl.constexpr,
557):
558 tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")
560 LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
562 bhid = tl.program_id(2)
563 off_chz = (bhid * Q_CTX).to(tl.int64)
564 batch_id = bhid // H
565 q_head_id = bhid % H
566 kv_head_id = q_head_id // GROUP_HEAD
567 adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
568 kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)
570 pid = tl.program_id(0)
572 # offset pointers for batch/head
573 Q += adj
574 K += kv_adj
575 V += kv_adj
576 DO += adj
577 DQ += adj
578 DK += adj
579 DV += adj
580 M += off_chz
581 D += off_chz
583 # load scales
584 offs_k = tl.arange(0, BLOCK_DMODEL)
586 start_n = pid * BLOCK_N1
587 start_m = start_n
589 MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
590 offs_n = start_n + tl.arange(0, BLOCK_N1)
591 offs_n_mask = offs_n < KV_CTX
593 dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
594 dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
596 # load K and V: they stay in SRAM throughout the inner loop.
597 key = tl.load(
598 K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
599 mask=offs_n_mask[:, None],
600 other=0.0,
601 )
602 value = tl.load(
603 V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
604 mask=offs_n_mask[:, None],
605 other=0.0,
606 )
608 num_steps = BLOCK_N1 // MASK_BLOCK_M1
610 dk, dv = _attn_bwd_dkdv(
611 dk,
612 dv, #
613 Q,
614 key,
615 value,
616 sm_scale, #
617 DO, #
618 M,
619 D, #
620 stride_tok,
621 stride_d, #
622 H,
623 Q_CTX, #
624 KV_CTX, #
625 MASK_BLOCK_M1,
626 BLOCK_N1,
627 BLOCK_DMODEL, #
628 start_n,
629 start_m,
630 num_steps, #
631 MASK=True, #
632 )
634 # Compute dK and dV for non-masked blocks.
635 start_m += num_steps * MASK_BLOCK_M1
636 remaining_m = Q_CTX - start_m
637 num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1
639 if num_steps > 0 and start_m < Q_CTX:
640 dk, dv = _attn_bwd_dkdv( #
641 dk,
642 dv, #
643 Q,
644 key,
645 value,
646 sm_scale, #
647 DO, #
648 M,
649 D, #
650 stride_tok,
651 stride_d, #
652 H,
653 Q_CTX, #
654 KV_CTX, #
655 BLOCK_M1,
656 BLOCK_N1,
657 BLOCK_DMODEL, #
658 start_n,
659 start_m,
660 num_steps, #
661 MASK=False, #
662 )
664 dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
665 tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])
667 # Write back dK.
668 dk *= sm_scale
669 dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
670 tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])
672 # THIS BLOCK DOES DQ:
673 MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
674 start_m = pid * BLOCK_M2
675 end_n = min(start_m + BLOCK_M2, KV_CTX) # Ensure end_n does not exceed N_CTX
676 num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2
678 offs_m = start_m + tl.arange(0, BLOCK_M2)
679 offs_m_mask = offs_m < Q_CTX
681 query = tl.load(
682 Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
683 mask=offs_m_mask[:, None],
684 other=0.0,
685 )
686 dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
687 do = tl.load(
688 DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
689 mask=offs_m_mask[:, None],
690 other=0.0,
691 )
693 m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
694 m = m[:, None]
696 # Stage 1 - Compute dQ for masked (diagonal) blocks.
697 # NOTE: This code scans each row of QK^T backward (from right to left,
698 # but inside each call to _attn_bwd_dq, from left to right), but that's
699 # not due to anything important. I just wanted to reuse the loop
700 # structure for dK & dV above as much as possible.
702 if num_steps > 0:
703 dq = _attn_bwd_dq(
704 dq,
705 query,
706 K,
707 V, #
708 do,
709 m,
710 D, #
711 stride_tok,
712 stride_d, #
713 H,
714 Q_CTX, #
715 KV_CTX, #
716 BLOCK_M2,
717 MASK_BLOCK_N2,
718 BLOCK_DMODEL, #
719 start_m,
720 start_n,
721 num_steps, #
722 MASK=True, #
723 )
725 # Stage 2 - non-masked blocks
726 stage2_end_n = start_n
727 stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2
729 if stage2_num_steps > 0:
730 dq = _attn_bwd_dq(
731 dq,
732 query,
733 K,
734 V, #
735 do,
736 m,
737 D, #
738 stride_tok,
739 stride_d, #
740 H,
741 Q_CTX, #
742 KV_CTX, #
743 BLOCK_M2,
744 BLOCK_N2,
745 BLOCK_DMODEL, #
746 start_m,
747 stage2_end_n - stage2_num_steps * BLOCK_N2,
748 stage2_num_steps, #
749 MASK=False, #
750 )
751 # Write back dQ.
752 dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
753 dq *= LN2
754 # tl.store(dq_ptrs, dq)
756 tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton fused-attention forward kernel.

    Args:
        query, key, value: (batch, heads, seqlen, head_dim) tensors. Key/value
            may have fewer heads than query when ``enable_gqa`` is True.
        attn_mask: optional (batch, heads, q_len, kv_len) mask. A bool mask is
            converted to an additive float mask (-1e6 at masked positions).
        dropout_p: must be 0.0 — dropout is not supported.
        is_causal: apply a causal (lower-triangular) mask.
        scale: softmax scaling factor; defaults to 1/sqrt(head_dim).
        enable_gqa: allow q_head_num != kv_head_num (grouped-query attention).

    Returns:
        (o, M): the attention output (same shape/dtype family as query, dtype
        of value) and the per-row log2-sum-exp statistics needed by backward.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    # Fixed typo in message: "Currenty" -> "Currently".
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # STAGE=3 enables the causal two-phase schedule inside the kernel.
    stage = 3 if is_causal else 1

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    # One program per (query block, batch*head); BLOCK_M comes from autotuning.
    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # Convert a keep/drop bool mask into an additive penalty mask.
            attn_mask = attn_mask.to(query.dtype) * -1.0e6
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        # Dummy strides; the kernel never dereferences the mask in this case.
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    # Per-row logsumexp (log2 domain), consumed by the backward pass.
    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            query.shape[0],
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            query.shape[2],  #
            key.shape[2],  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton fused-attention backward kernels.

    Args:
        do: gradient w.r.t. the attention output, contiguous.
        query, key, value, o: the contiguous forward inputs/output.
        M: per-row log2-sum-exp statistics returned by the forward pass.
        attn_mask: unsupported here; present for signature symmetry.
        dropout_p: must be 0.0 — dropout is not supported.
        is_causal: whether the forward pass used a causal mask.
        scale: softmax scaling factor; defaults to 1/sqrt(head_dim).
        enable_gqa: whether grouped-query attention was used.

    Returns:
        (dq, dk, dv): gradients w.r.t. query, key and value. For GQA, dk/dv
        are reduced over the query-head group back to kv_head shape.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    # Fixed typo in message: "Currenty" -> "Currently".
    assert dropout_p == 0.0, "Currently only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    # The kernel addresses all of Q/O/DO (and K/V) with one shared stride set.
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2

    # Pre-scale K by sm_scale / ln(2) so the kernel can work in the exp2
    # domain; _attn_bwd rescales dk/dq back at the end.
    RCP_LN2 = 1.0 / math.log(2)

    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256

    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q;
    # for GQA they are reduced over the group below.
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # delta = rowsum(o * do), needed by both dK/dV and dQ inner loops.
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        Q_HEAD,
        Q_CTX,  #
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,  #
    )

    # Size the grid for the largest BLOCK_N1 the autotuner may pick so every
    # config covers the full KV context.
    max_block_n1 = (
        max([cfg.kwargs["BLOCK_N1"] for cfg in config_backward])
        if config_backward
        else 128
    )
    grid = (triton.cdiv(Q_CTX, max_block_n1), 1, BATCH * Q_HEAD)

    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,  #
        M,
        delta,  #
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),  #
        key.stride(0),
        key.stride(1),  #
        Q_HEAD,
        Q_CTX,  #
        KV_CTX,  #
        KV_HEAD,  #
        GROUP_HEAD=group_head,  #
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        BLOCK_DMODEL=BLOCK_DMODEL,  #
    )

    # GQA: fold the per-query-head dk/dv back onto the shared KV heads.
    if group_head > 1:
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Triton SDPA kernels.

    forward() launches the fused attention forward kernel and stashes the
    tensors the backward pass needs; backward() replays them through the
    Triton backward kernels.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Resolve the default softmax scale here so backward sees a concrete value.
        sm_scale = scale if scale is not None else 1.0 / (key.shape[-1] ** 0.5)
        out, logsumexp = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, out, logsumexp)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return out

    @staticmethod
    def backward(ctx, do):
        query, key, value, out, logsumexp = ctx.saved_tensors
        grads = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            out,
            logsumexp,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One gradient slot per forward() input; the five non-tensor
        # arguments (attn_mask..enable_gqa) get None.
        return (*grads, None, None, None, None, None)
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Public SDPA entry point dispatching to the autograd Function.

    Mirrors torch.nn.functional.scaled_dot_product_attention's signature;
    Function.apply only accepts positional arguments, so everything is
    forwarded positionally.
    """
    args = (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
    return ScaleDotProductAttention.apply(*args)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward entry point backed by mha_fwd/mha_varlan_fwd.

    Unsupported head dims are zero-padded up to the next supported size and
    the output is sliced back; the softmax scale is always computed from the
    ORIGINAL head dim so padding does not change the math.

    Returns:
        (out, lse, philox_seed, philox_offset, p) as produced by the flash API.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Round up to the smallest supported head dim and zero-pad q/k/v.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Bug fixes vs previous version:
    #  * `scale or default` treated an explicit scale=0.0 as unset; use an
    #    `is None` check, consistent with the other wrappers in this file.
    #  * the varlen branch below previously forwarded the raw `scale`
    #    (possibly None) instead of the resolved `softmax_scale`.
    softmax_scale = scale if scale is not None else 1.0 / (original_head_dim**0.5)
    # The flash API encodes "no sliding window" as -1.
    if window_size_left is not None:
        non_null_window_left = window_size_left
    else:
        non_null_window_left = -1
    if window_size_right is not None:
        non_null_window_right = window_size_right
    else:
        non_null_window_right = -1

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    # Strip the zero padding added above, if any.
    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return *x* with a unit last-dim stride, passing ``None`` through.

    Tensors whose last stride is already 1 are returned unchanged (no copy).
    """
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # Only the FA2 code path exists in this backend.
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
    if use_c_extension:
        # Fast path: delegate to the compiled flag_gems extension op.
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
    else:
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
        # Exactly one of cu_seqlens_k / seqused_k must describe the KV lengths.
        assert (
            cu_seqlens_k is not None or seqused_k is not None
        ), "cu_seqlens_k or seqused_k must be provided"
        assert (
            cu_seqlens_k is None or seqused_k is None
        ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
        assert (
            block_table is None or seqused_k is not None
        ), "seqused_k must be provided if block_table is provided"
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # custom op does not support non-tuple input
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
        # NOTE(review): the comment below says "all zeros" but empty_like is
        # uninitialized; presumed harmless because seqused_k takes precedence
        # when cu_seqlens_k is None — confirm against flash_api.
        dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
        # Accept either python ints or 0-dim tensors for the max lengths.
        max_seqlen_q = (
            max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
        )
        max_seqlen_k = (
            max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
        )
        out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
            # still wants it so we pass all zeros
            dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            return_softmax_lse and dropout_p > 0,
            None,
        )

        return (out, softmax_lse) if return_softmax_lse else out
1357 return (out, softmax_lse) if return_softmax_lse else out