Coverage for src/flag_gems/ops/flash_kernel.py: 13%
574 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import triton
2import triton.language as tl
4from flag_gems import runtime
5from flag_gems.utils import libentry, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit value into two 32-bit halves.

    Note: despite the name, the return order is (high, low).
    """
    hi_half = (x >> 32).to(tl.uint32)
    lo_half = (x & 0xFFFFFFFF).to(tl.uint32)
    return hi_half, lo_half
@triton.jit
def u64_from_lohi(lo, hi):
    """Recombine 32-bit halves (lo, hi) into one 64-bit value.

    BUGFIX: `+` binds tighter than `<<` in Python, so the previous
    `hi << 32 + lo` parsed as `hi << (32 + lo)`, shifting by a
    data-dependent amount instead of placing `hi` in the upper word.
    Parentheses force the intended `(hi << 32) + lo`.
    """
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    """Counter-based Philox-style RNG: returns four 32-bit random words.

    The 64-bit `seed` becomes the key, while `offset` and `subsequence`
    form the 128-bit counter (c0..c3). Bit-exact reproducibility matters:
    dropout masks are re-derived from the same (seed, subsequence, offset)
    in forward/backward, so statement order here must not change.

    NOTE(review): 6 unrolled rounds plus the tail round = 7 total, fewer
    than the canonical 10-round Philox4x32-10 — presumably a deliberate
    speed/quality trade-off; confirm it matches the host-side RNG contract.
    """
    # Round-key increments (Weyl constants) for the two key words.
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # pragma unroll
    # Multipliers for the S-box step of each round.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    for _ in tl.static_range(6):
        # One Philox round: two 32x32->64 multiplies, then mix the
        # hi/lo products with the counter words and round keys.
        res0 = kPhiloxSA * c0.to(tl.uint64)
        res1 = kPhiloxSB * c2.to(tl.uint64)
        res0_x, res0_y = u64_to_lohi(res0)
        res1_x, res1_y = u64_to_lohi(res1)
        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
        # Bump the round keys for the next round.
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final (unrolled) round without a key bump afterwards.
    res0 = kPhiloxSA * c0.to(tl.uint64)
    res1 = kPhiloxSB * c2.to(tl.uint64)
    res0_x, res0_y = u64_to_lohi(res0)
    res1_x, res1_y = u64_to_lohi(res1)
    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Drop the entries of P selected by `mask`.

    When `encode_dropout_in_sign_bit` is set, dropped entries are negated
    (so the drop pattern survives in the sign bit); otherwise they are
    zeroed in P's own dtype.
    """
    if encode_dropout_in_sign_bit:
        dropped = -P
    else:
        dropped = (P * 0).to(P.dtype)
    return tl.where(mask, dropped, P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply counter-based dropout to the probability tile P (no-op unless
    `is_dropout`). Randomness is derived from (philox_seed, philox_offset)
    plus the tile's absolute position, so the same mask can be regenerated
    deterministically for a given (batch, head, row, col).
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        row = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Down scale col_idx by 4: each philox_ call yields four 32-bit
        # words (r0..r3), i.e. it covers 4 columns per counter value.
        col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # `subsequence * 0` only broadcasts `offset` to subsequence's
        # (BLOCK_M, BLOCK_N // 4) shape; its value is unchanged.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        # Interleave the four random words back into a full BLOCK_N tile.
        r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)

        # Only the low byte of each word is compared to the threshold.
        # NOTE(review): entries with byte >= p_dropout_uint8 are DROPPED,
        # so p_dropout_uint8 presumably encodes the keep probability
        # quantized to uint8 — confirm against the host-side setup.
        mask = (r & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi positional bias to the score tile S (no-op unless
    `is_alibi`). `alibi_slope` is the per-head slope, already divided by
    the softmax scale by the caller so the bias survives rescaling.
    """
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softmax(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #     -4, -3, -2,  X,  X,
            #     -4, -3, -2, -1,  X,
            #     -4, -3, -2, -1,  0,
            # ]
            # bias_2 = [
            #     -2, -1,  0,  X,  X,
            #     -3, -2, -1,  0,  X,
            #     -4, -3, -2, -1,  0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += bias
        else:
            # Non-causal: bias falls off with |distance|, with the query and
            # key sequences aligned at their right (bottom-right) ends.
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += bias

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out positions of the score tile S to -inf.

    Handles causal masking, local (sliding-window) masking, and the ragged
    right edge of the key sequence when shapes are not block-aligned
    (`is_even_mn` False). Row/col bounds align the two sequences at their
    right ends (seqlen_k - seqlen_q offset).
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    # need_mask: tl.constexpr = is_causal | is_local
    if need_mask:
        # Extra care should be taken to avoid off-by-one errors: both col_lb and col_rb are inclusive!
        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        col_rb = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            # Everything right of the (shifted) diagonal is masked.
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            # Outside the [col_lb, col_rb] window on either side is masked.
            S = tl.where(
                (col_idx[None, :] > col_rb[:, None])
                | (col_idx[None, :] < col_lb[:, None]),
                float("-inf"),
                S,
            )

        if (not is_local) & (not is_causal) & (not is_even_mn):
            # Only the ragged tail of the key sequence needs masking.
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """One step of the online (streaming) softmax.

    Folds a new score tile S into the running state: rescales the output
    accumulator `O_acc` and running `row_sum` by exp2 of the change in the
    running row maximum, and returns the tile's unnormalized probabilities
    P = exp2((S - max) * scale). `softmax_scale_log2e` is the softmax scale
    pre-multiplied by log2(e) so exp2 can be used instead of exp.
    `is_border` guards fully-masked rows (row_max == -inf) against NaNs.
    """
    prev_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        # A fully-masked row has row_max == -inf; substitute 0 so the
        # rescale factor below stays finite.
        cur_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        cur_max = row_max

    # Rescale previous partial results to the new maximum.
    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
    row_sum *= p_scale
    O_acc *= p_scale[:, None]

    max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Tanh soft-capping of attention scores (no-op unless `is_softcap`).

    NOTE(review): only tanh(S * softcap) is computed here; the conventional
    multiply-back by the cap value is presumably folded into the softmax
    scale / softcap constant by the host — confirm against the caller.
    """
    if is_softcap:
        S = tl_extra_shim.tanh(S * softcap)

    return S
def block_m_splitkv_heuristic(headdim):
    """Choose BLOCK_M for the split-kv kernel from the head dimension."""
    if headdim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic(headdim):
    """Choose BLOCK_N for the split-kv kernel from the head dimension."""
    if headdim <= 64:
        return 64
    return 32
def is_even_mn(M, N, BM, BN, WL, WR):
    """Return True iff seqlens (M, N) and windows (WL, WR) align with the
    block sizes (BM, BN), i.e. no ragged-edge masking is needed.

    A window of -1 means "unbounded" and always passes its check.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    if WL != -1 and WL % BN != 0:
        return False
    if WR != -1 and WR % BN != 0:
        return False
    return True
def block_m_splitkv_heuristic_spec_args(args):
    """Heuristics-dict adapter for `block_m_splitkv_heuristic` (keyed on head dim "d")."""
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic_spec_args(args):
    """Heuristics-dict adapter for `block_n_splitkv_heuristic` (keyed on head dim "d")."""
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32
def is_even_mn_spec_args(args):
    """Heuristics-dict version of `is_even_mn`: True iff seqlens and windows
    align with the chosen block sizes (window -1 means unbounded)."""
    seq_q = args["seqlen_q"]
    seq_k = args["seqlen_k"]
    bm = args["BLOCK_M"]
    bn = args["BLOCK_N"]
    wl = args["window_size_left"]
    wr = args["window_size_right"]

    if seq_q % bm != 0 or seq_k % bn != 0:
        return False
    if seq_q % seq_k != 0 and seq_k % seq_q != 0:
        return False
    if wl != -1 and wl % bn != 0:
        return False
    if wr != -1 and wr % bn != 0:
        return False
    return True
def keep(cfg, must_keep=None):
    """Autotune config filter: retain a small whitelist of (BLOCK_M,
    BLOCK_N, num_warps) shapes, plus anything explicitly in `must_keep`."""
    shape = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    if shape in ((128, 32, 4), (128, 128, 8)):
        return True
    # we always keep configurations in `must_keep`
    return must_keep and cfg in must_keep
def prune_fwd_configs(configs, nargs, **kwargs):
    """Early autotune pruning: with dropout enabled, keep only configs with
    num_warps == 4 and num_stages < 4; otherwise keep everything."""
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]
def flash_fwd_kernel_heur_block_k(args):
    """BLOCK_K heuristic: round the head dimension up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.autotune(
    configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["d", "is_dropout"],
)
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel.

    One program handles one BLOCK_M row-block of Q for one (batch, head)
    pair: program_id(0) is the row block, program_id(1) enumerates
    batch * h and is decomposed into (bid, hid) below. It streams over
    K/V column blocks in two phases — a masked phase covering the trailing
    `masking_cols` columns (causal diagonal / local window / ragged edge)
    and an unmasked main loop — maintaining an online softmax state
    (acc_, rowmax_, rowsum_), then writes the output tile and per-row LSE.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        # Leftmost key column this row-block can attend to (sliding window).
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        # Rightmost column reachable from the last row of this block.
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that need explicit masking.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        # philox_args holds (seed, offset) for the counter-based RNG.
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide by the softmax scale so the bias survives rescaling.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # ---- Load the Q tile for this row block ----
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # NOTE(review): `&` binds tighter than `==`, so this condition parses as
    # `(IS_EVEN_MN & d) == BLOCK_K`, which is almost never True — loads then
    # take the masked path. Result stays correct (mask is conservative), but
    # `IS_EVEN_MN and (d == BLOCK_K)` was likely intended; confirm.
    if IS_EVEN_MN & d == BLOCK_K:
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        # Base pointer of this tile inside the returned-probabilities buffer.
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    # ---- K/V base pointers (grouped-query: h_hk_ratio query heads per kv head) ----
    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is laid out transposed (BLOCK_K x BLOCK_N) for the dot product.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    # ---- Phase 1: masked column blocks, walked right-to-left from col_max ----
    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            # (same `&` vs `==` precedence caveat as on the Q load above)
            if IS_EVEN_MN & d == BLOCK_K:
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Store P with dropout encoded in the sign bit so the
                    # returned matrix carries the drop pattern.
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                # The accumulated P is zeroed (not sign-encoded) at drops.
                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & d == BLOCK_K:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # ---- Phase 2: interior column blocks, no masking needed ----
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(p_bk0 + off, mask=dmask[:, None], cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        # Interior blocks: causal/ragged masking is off; only the local
        # window's left edge can still apply.
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # NOTE(review): `|` binds tighter than `==`, so the condition parses as
    # `rowsum_ == (0 | (rowsum_ != rowsum_))`; the NaN arm therefore does not
    # trigger as written. Likely intended `(rowsum_ == 0) | (rowsum_ != rowsum_)`
    # — confirm before changing.
    lse = tl.where(
        rowsum_ == 0 | (rowsum_ != rowsum_),
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Normalize; with dropout, also rescale by 1/(1-p).
    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & d == BLOCK_K:
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder for a batch/head-parallel forward kernel; not implemented."""
    # (TODO)
    pass
def flash_fwd_splitkv_kernel_heur_block_k(args):
    """BLOCK_K heuristic for the split-kv kernel: head dim rounded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward kernel.

    Grid: (row block, split, batch * head). Each program processes
    `blocks_per_split` consecutive K/V column blocks for one Q row-block
    and writes a PARTIAL output and LSE into per-split buffers
    (o_splits / lse_splits); `flash_fwd_splitkv_combine_kernel` reduces
    the splits afterwards. Splits fully inside the unmasked region take a
    fast path; otherwise every block is masked.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    # Range of column blocks owned by this split.
    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        # Causal: no column block beyond this row block's diagonal is needed.
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        # NOTE(review): integer 0 here vs. 0.0 in flash_fwd_kernel — harmless
        # but inconsistent.
        alibi_slope = 0

    # First column block index that requires masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    # ---- Load the Q tile ----
    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # NOTE(review): `&` binds tighter than `==`, so this parses as
    # `(IS_EVEN_MN & BLOCK_K) == d` — the masked path is taken in most cases.
    # Correct but likely unintended; same caveat applies to the loads below.
    if IS_EVEN_MN & BLOCK_K == d:
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    # ---- K/V base pointers (grouped-query heads) ----
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is laid out transposed (BLOCK_K x BLOCK_N) for the dot product.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # Masked path: this split overlaps the causal diagonal / ragged edge.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN & d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # NOTE(review): `|` binds tighter than `==` (see flash_fwd_kernel), and
    # the empty-row sentinel is float("-inf") here but float("inf") in
    # flash_fwd_kernel — confirm the combine kernel expects -inf.
    lse = tl.where(
        rowsum_ == 0 | (rowsum_ != rowsum_),
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & BLOCK_K == d:
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Reduce the per-split partial outputs of the split-kv kernel.

    Each program combines BLOCK_M query rows: it computes a numerically
    stable log-sum-exp over the split LSEs, derives each split's softmax
    weight Zi/Z, and emits the weighted sum of the split outputs plus the
    final LSE. `MAX_N_SPLITS` is the compile-time upper bound on splits;
    rows/splits beyond (q_total, n_splits) are masked out.
    """
    pid = tl.program_id(0)
    # Advance all pointers to this program's row block.
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    # Inactive lanes read -inf so they contribute exp(-inf) = 0 below.
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    # Per-split combination weight: Zi / Z.
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    # Gather all split outputs: (BLOCK_M rows, MAX_N_SPLITS splits, BLOCK_K dims).
    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    # Weighted sum over the split axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])
@triton.jit
def virtual_to_cache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    boundary_check: tl.constexpr = False,
):
    """Translate per-batch virtual kv indices into physical kv-cache rows.

    `page_table_ptr` must already point at the current batch element's page
    table; `block_size` is the number of kv rows per cache page. With
    `boundary_check`, out-of-range indices read page-table entry 0 (callers
    are expected to mask the corresponding data).
    """
    # virtual_index is the kv sequence index in the current batch element
    page_idx = virtual_index // block_size
    in_page_offset = virtual_index % block_size
    if boundary_check:
        physical_page = tl.load(
            page_table_ptr + page_idx,
            mask=virtual_index < max_virtual_index,
            other=0,
        ).to(tl.int32)
    else:
        physical_page = tl.load(page_table_ptr + page_idx).to(tl.int32)
    return physical_page * block_size + in_page_offset
@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    boundary_check: tl.constexpr = False,
):
    """Gather one (K, V) tile from the paged kv-cache.

    Virtual kv indices are first mapped to physical cache rows via the page
    table; K is returned transposed (BLOCK_K x N) and V row-major
    (N x BLOCK_K). Rows past `max_virtual_index` (and head-dim lanes past
    `d` when d < BLOCK_K) load as 0.0.
    """
    cache_row = virtual_to_cache(
        virtual_index, max_virtual_index, page_table_ptr, block_size, boundary_check
    )
    k_offset = tl.arange(0, BLOCK_K)[:, None] + cache_row[None, :] * k_row_stride
    v_offset = tl.arange(0, BLOCK_K)[None, :] + cache_row[:, None] * k_row_stride
    seq_mask_k = virtual_index[None, :] < max_virtual_index[None, :]
    seq_mask_v = virtual_index[:, None] < max_virtual_index[:, None]
    if d == BLOCK_K:
        bK = tl.load(k_ptr_base + k_offset, mask=seq_mask_k, other=0.0)
        bV = tl.load(v_ptr_base + v_offset, mask=seq_mask_v, other=0.0)
    else:
        # Head dim is padded up to BLOCK_K: also mask the padding lanes.
        dim_mask_k = tl.arange(0, BLOCK_K)[:, None] < d
        dim_mask_v = tl.arange(0, BLOCK_K)[None, :] < d
        bK = tl.load(k_ptr_base + k_offset, mask=dim_mask_k & seq_mask_k, other=0.0)
        bV = tl.load(v_ptr_base + v_offset, mask=dim_mask_v & seq_mask_v, other=0.0)
    return bK, bV
@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Forward flash-attention for variable-length (cu_seqlens-packed) and/or
    # paged-KV batches. One program computes one BLOCK_M x d output tile for a
    # (m_block, batch, head) triple, plus the matching log-sum-exp rows.
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)
    # num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # Resolve this batch element's Q extent and base offsets: packed varlen
    # layout via cu_seqlens, or a regular strided batch otherwise.
    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # Same for K/V. FIX: compute the per-batch sequence offsets in BOTH
    # branches; the original bound k_bos only under is_cu_seqlens_k but used
    # it unconditionally in the non-paged load path, so the
    # (not is_cu_seqlens_k, not is_paged) configuration failed to compile.
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        k_seq_offset = k_bos * k_row_stride
        # NOTE(review): V offset reuses K's row stride, i.e. assumes K and V
        # share a row stride in the packed layout — confirm with callers.
        v_seq_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        k_seq_offset = bid * k_batch_stride
        v_seq_offset = bid * v_batch_stride

    # seqused_k optionally clamps how much of the cached K/V is attended to.
    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: this m-block lies entirely past the end of the query.
    # FIX: >= instead of > — the exactly-at-the-end block has no valid rows
    # (every store below is fully masked), so exit early instead of working.
    if m_block * BLOCK_M >= q_len:
        return

    # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
    is_even_mn: tl.constexpr = False

    # Range of k-blocks this row block attends to (sliding-window / causal).
    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    if is_paged:
        page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # K/V heads may be shared by h_hk_ratio query heads (GQA/MQA).
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head.
    # NOTE(review): V uses K's head stride here — assumes matching layouts.
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    # Online-softmax running state: output accumulator, row max, row sum.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide so the bias survives the later softmax scaling of S.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Number of trailing k-blocks that need the full boundary mask; the rest
    # are interior blocks handled by the faster second loop.
    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    n_block = n_block_max - 1
    # Pass 1: masked k-blocks at the causal/local/tail boundary, right to left.
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                boundary_check=True,
            )
        else:
            start_n = n_block * BLOCK_N
            k_ptr_seq = k_ptr_base + k_seq_offset
            v_ptr_seq = v_ptr_base + v_seq_offset
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK, boundary_check=(0, 1))
            bK = tl.trans(bK)
            bV = tl.load(gV, boundary_check=(0, 1))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # FIX: pass (row_start, col_start) = (M offset, N offset) to match
            # apply_dropout's signature and the main loop below; the original
            # swapped the two here, generating a different dropout pattern in
            # the masking pass than in the interior pass.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Pass 2: interior k-blocks that need no causal/tail masking.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
            )
        else:
            start_n = n_block * BLOCK_N
            k_ptr_seq = k_ptr_base + k_seq_offset
            v_ptr_seq = v_ptr_base + v_seq_offset
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            # Interior blocks are fully in range: unchecked loads.
            bK = tl.load(gK)
            bK = tl.trans(bK)
            bV = tl.load(gV)
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # Finalize LSE. A zero or NaN row sum (fully-masked row) yields inf lse
    # and a neutral 1.0 normalizer.
    # FIX: parenthesize the comparison — `rowsum_ == 0 | (...)` parsed as
    # `rowsum_ == (0 | (...))` because `|` binds tighter than `==` in Python,
    # so NaN row sums bypassed the guard and leaked NaN into lse / inv_sum.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )