Coverage for src/flag_gems/runtime/backend/_hygon/ops/flash_kernel.py: 0%
551 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import triton
2import triton.language as tl
4from flag_gems import runtime
5from flag_gems.utils import libentry, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit value into its two 32-bit halves.

    Returns (high_word, low_word), both as tl.uint32.
    """
    high = (x >> 32).to(tl.uint32)
    low = (x & 0xFFFFFFFF).to(tl.uint32)
    return high, low
@triton.jit
def u64_from_lohi(lo, hi):
    """Recombine two 32-bit halves into a single tl.uint64.

    Inverse of ``u64_to_lohi``: ``hi`` becomes the upper 32 bits and
    ``lo`` the lower 32 bits.

    Bug fix: the original expression ``hi << 32 + lo`` parses as
    ``hi << (32 + lo)`` because ``+`` binds tighter than ``<<`` —
    parenthesize the shift so the halves are actually recombined.
    """
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    """Philox4x32 counter-based RNG round function.

    Produces four 32-bit random words from (seed, subsequence, offset).
    The 64-bit seed supplies the two key words; the 64-bit offset and
    subsequence supply the four counter words. Runs 6 unrolled rounds
    plus one final round (7 total).
    NOTE(review): the reference philox4x32-10 uses 10 rounds; 7 appears
    to be a deliberate speed/quality trade-off here — confirm against
    the dropout reference implementation for this backend.
    """
    # Key-schedule increments (Weyl constants from the Philox paper).
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    # u64_to_lohi returns (hi, lo); the counter/key word order below is
    # consistent with how apply_dropout consumes the four outputs.
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # pragma unroll
    # Round multipliers (Philox S-box constants).
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    for _ in tl.static_range(6):
        # One Philox round: two 32x32->64 multiplies, then mix the
        # hi/lo products with the other counter words and the key.
        res0 = kPhiloxSA * c0.to(tl.uint64)
        res1 = kPhiloxSB * c2.to(tl.uint64)
        res0_x, res0_y = u64_to_lohi(res0)
        res1_x, res1_y = u64_to_lohi(res1)
        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
        # Bump the key for the next round (key schedule).
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round without a key bump.
    res0 = kPhiloxSA * c0.to(tl.uint64)
    res1 = kPhiloxSB * c2.to(tl.uint64)
    res0_x, res0_y = u64_to_lohi(res0)
    res1_x, res1_y = u64_to_lohi(res1)
    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Apply a precomputed dropout mask to the probability tile P.

    When ``encode_dropout_in_sign_bit`` is set, dropped entries are
    negated (the sign bit records the drop for the backward pass);
    otherwise dropped entries are zeroed.
    """
    if encode_dropout_in_sign_bit:
        return tl.where(mask, -P, P)
    else:
        return tl.where(mask, (P * 0).to(P.dtype), P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply philox-based dropout to the (BLOCK_M, BLOCK_N) tile P.

    The mask is a pure function of (philox_seed, philox_offset, bid, hid,
    row, col), so forward and backward passes regenerate identical masks.
    Each philox call yields 4 random words, so columns are processed in
    groups of 4 (BLOCK_N must be a multiple of 4 — assumed, not checked).
    No-op when ``is_dropout`` is false.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        row = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Down scale col_idx by 4
        col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        # Unique counter per (row, 4-col group) position.
        subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # NOTE(review): "* 0" only broadcasts offset to subsequence's
        # shape; it does not change the value.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        # Interleave the 4 words back into a full BLOCK_N-wide tile.
        r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)

        # Keep when the low byte of the random word clears the threshold.
        mask = (r & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add an ALiBi positional bias to the score tile S.

    ``alibi_slope`` is the per-head slope (already divided by the softmax
    scale by the caller). No-op when ``is_alibi`` is false.
    """
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softamx(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #    -4, -3, -2,  X,  X,
            #    -4, -3, -2, -1,  X,
            #    -4, -3, -2, -1,  0,
            # ]
            # bias_2 = [
            #    -2, -1,  0,  X,  X,
            #    -3, -2, -1,  0,  X,
            #    -4, -3, -2, -1,  0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += bias
        else:
            # Non-causal: symmetric bias around the (right-aligned) diagonal.
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += bias

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Mask out-of-scope score entries with -inf before softmax.

    Handles causal masking, local (sliding-window) masking, and the
    right-edge padding mask when seqlen_k is not a multiple of BLOCK_N
    (``is_even_mn`` false). Rows are right-aligned: row i of Q attends
    up to column i + (max_seqlen_k - max_seqlen_q).
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    # need_mask: tl.constexpr = is_causal | is_local
    if need_mask:
        # Extra care should be taken to void one-off errors: both col_lb and col_rb are inclusive!
        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        col_rb = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            # Mask everything to the right of the (shifted) diagonal.
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            # Mask outside the [col_lb, col_rb] sliding window.
            S = tl.where(
                (col_idx[None, :] > col_rb[:, None])
                | (col_idx[None, :] < col_lb[:, None]),
                float("-inf"),
                S,
            )

        if (not is_local) & (not is_causal) & (not is_even_mn):
            # Only padding needs masking: drop columns past seqlen_k.
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """One step of the online (streaming) softmax.

    Folds the new score tile S into the running row maximum and row sum,
    rescaling the existing output accumulator O_acc so all contributions
    share the current maximum. ``softmax_scale_log2e`` is the softmax
    scale pre-multiplied by log2(e) so exp2 can be used throughout.
    ``is_border`` guards fully-masked rows (row_max == -inf) that can
    occur at causal/local boundaries, avoiding inf - inf = NaN.

    Returns (O_acc, P, row_max, row_sum) where P = exp2-normalized
    probabilities for this tile (unnormalized by row_sum).
    """
    prev_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        # Fully-masked row: substitute 0 so the rescale factor is finite.
        cur_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        cur_max = row_max
    # Rescale previous accumulator/sum to the new running maximum.
    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
    row_sum *= p_scale
    O_acc *= p_scale[:, None]

    max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Optionally soft-cap attention scores via tanh(S * softcap).

    No-op when ``is_softcap`` is false.
    """
    if is_softcap:
        return tl_extra_shim.tanh(S * softcap)
    return S
def block_m_splitkv_heuristic(headdim):
    """Row-tile size for the split-KV kernel: shrink tiles for wide heads."""
    if headdim <= 128:
        return 128
    return 32
def block_n_splitkv_heuristic(headdim):
    """Column-tile size for the split-KV kernel: shrink tiles for wide heads."""
    if headdim <= 64:
        return 64
    return 16
def is_even_mn(M, N, BM, BN, WL, WR):
    """Report whether all tile/window boundaries divide evenly.

    True iff seqlen_q (M) and seqlen_k (N) are multiples of their block
    sizes, one sequence length divides the other, and each enabled
    window bound (-1 = disabled) is a multiple of BN. When true the
    kernels can skip boundary masking.
    """
    return (
        M % BM == 0
        and N % BN == 0
        and (M % N == 0 or N % M == 0)
        and (WL == -1 or WL % BN == 0)
        and (WR == -1 or WR % BN == 0)
    )
def block_m_splitkv_heuristic_spec_args(args):
    """Autotune-dict adapter: BLOCK_M from the head dim ``args["d"]``."""
    if args["d"] <= 128:
        return 128
    return 32
def block_n_splitkv_heuristic_spec_args(args):
    """Autotune-dict adapter: BLOCK_N from the head dim ``args["d"]``."""
    if args["d"] <= 64:
        return 64
    return 16
def is_even_mn_spec_args(args):
    """Autotune-dict adapter for the even-M/N alignment check.

    Same predicate as ``is_even_mn`` but reads its inputs from the
    heuristics ``args`` mapping.
    """
    sq = args["seqlen_q"]
    sk = args["seqlen_k"]
    bm = args["BLOCK_M"]
    bn = args["BLOCK_N"]
    wl = args["window_size_left"]
    wr = args["window_size_right"]
    return (
        sq % bm == 0
        and sk % bn == 0
        and (sq % sk == 0 or sk % sq == 0)
        and (wl == -1 or wl % bn == 0)
        and (wr == -1 or wr % bn == 0)
    )
def keep(cfg, must_keep=None):
    """Autotune-config filter for the forward kernel.

    Accepts a config when its (BLOCK_M, BLOCK_N, num_warps) triple is in
    the allowed set, or when it is explicitly listed in ``must_keep``.
    Return value is truthy/falsy (may be None/[] when rejected), which
    is all ``filter`` needs.
    """
    triple = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    allowed = ((32, 16, 4), (64, 16, 4), (128, 16, 4))
    # we always keep configurations in `must_keep`
    return triple in allowed or (must_keep and cfg in must_keep)
def prune_fwd_configs(configs, nargs, **kwargs):
    """Early-prune autotune configs for the forward kernel.

    With dropout enabled, only configs with num_warps == 4 and fewer
    than 4 stages are kept; otherwise all configs pass through.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]
def flash_fwd_kernel_heur_block_k(args):
    """BLOCK_K spans the whole head dim, padded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.autotune(
    configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["d", "is_dropout"],
)
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel.

    Grid: (num_m_blocks, batch * heads). Each program computes one
    BLOCK_M-row tile of O = softmax(Q @ K^T) @ V using the online
    softmax, with optional causal/local masking, ALiBi bias, softcap,
    and philox dropout. Writes O and the per-row log-sum-exp (LSE).
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        # Right-aligned diagonal: last row of this tile limits col_max.
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # masking_cols: how many trailing columns need the masked (slow) loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        # philox_args holds (seed, offset) as two scalars.
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide so the bias can be added to unscaled scores.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Load the Q tile for this row block (kept in registers throughout).
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # NOTE(review): `IS_EVEN_MN & d == BLOCK_K` parses as
    # `(IS_EVEN_MN & d) == BLOCK_K` (== binds looser than &) — confirm
    # this matches the intended `IS_EVEN_MN and d == BLOCK_K`.
    if IS_EVEN_MN & d == BLOCK_K:
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        # Base pointer of this tile inside the (b, h, seqlen_q_rounded,
        # seqlen_k_rounded) returned-softmax buffer.
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    # K/V base pointers; hid // h_hk_ratio maps query head -> kv head (GQA/MQA).
    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V uses k_batch_stride/k_head_stride — presumably K
    # and V share layout here; verify against the launcher.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    # ---- Masked tail loop: walk the last `masking_cols` columns from the
    # right edge backwards, with full boundary/causal/local masking. ----
    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            # NOTE(review): same `&` vs `==` precedence concern as above.
            if IS_EVEN_MN & d == BLOCK_K:
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Store sign-encoded dropped probabilities for debugging.
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                # Zero-out dropped entries for the actual accumulation.
                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & d == BLOCK_K:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # ---- Fast main loop over the fully-in-bounds columns (no masking). ----
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(p_bk0 + off, mask=dmask[:, None], cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        # Only the local window's left bound can still mask here.
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # NOTE(review): `rowsum_ == 0 | (rowsum_ != rowsum_)` parses as
    # `rowsum_ == (0 | (rowsum_ != rowsum_))` (| binds tighter than ==);
    # confirm this matches the intended `(rowsum_ == 0) | isnan(rowsum_)`.
    lse = tl.where(
        rowsum_ == 0 | (rowsum_ != rowsum_),
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Final softmax normalization (and dropout rescale, 1/(1-p)).
    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & d == BLOCK_K:
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder for a batch/head-parallel forward kernel (not implemented)."""
    # (TODO)
    pass
def flash_fwd_splitkv_kernel_heur_block_k(args):
    """BLOCK_K spans the whole head dim, padded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward kernel.

    Grid: (num_m_blocks, n_splits, batch * heads). Each program handles
    ``blocks_per_split`` consecutive key/value column blocks and writes a
    partial output and partial LSE for its split; the combine kernel
    merges splits afterwards. No dropout/softcap-store path here.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    # Column-block range owned by this split.
    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        # Causal: no column block beyond the last row's diagonal is needed.
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First column block (from the right) that requires masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    # Load the Q tile (kept in registers).
    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # NOTE(review): `IS_EVEN_MN & BLOCK_K == d` parses as
    # `(IS_EVEN_MN & BLOCK_K) == d` (== binds looser than &) — confirm
    # this matches the intended `IS_EVEN_MN and BLOCK_K == d`.
    if IS_EVEN_MN & BLOCK_K == d:
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    # K/V base pointers; hid // h_hk_ratio maps query head -> kv head (GQA/MQA).
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V uses k_batch_stride/k_head_stride — presumably K
    # and V share layout here; verify against the launcher.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # Masked path: this split overlaps the boundary/causal region.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            # NOTE(review): same `&` vs `==` precedence concern as above.
            if IS_EVEN_MN & d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # NOTE(review): `rowsum_ == 0 | (rowsum_ != rowsum_)` parses as
    # `rowsum_ == (0 | (rowsum_ != rowsum_))` (| binds tighter than ==);
    # confirm this matches the intended `(rowsum_ == 0) | isnan(rowsum_)`.
    lse = tl.where(
        rowsum_ == 0 | (rowsum_ != rowsum_),
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & BLOCK_K == d:
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge split-KV partial results into the final output and LSE.

    Each program combines BLOCK_M query rows across all n_splits using
    the log-sum-exp trick: per-row split weights are
    exp(lse_i - max_lse) / sum_j exp(lse_j - max_lse).
    """
    pid = tl.program_id(0)
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    # Per-split combination weights (rows sum to 1).
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    # Gather partial outputs: (BLOCK_M, MAX_N_SPLITS, BLOCK_K).
    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    # Weighted sum over the splits axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])
@triton.jit
def virtual_to_cache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    boundary_check: tl.constexpr = False,
):
    """Translate a logical KV position into a physical paged-cache row.

    With ``boundary_check`` the page-table load is masked for
    out-of-range indices (mapping them to page 0); the caller is then
    responsible for masking those positions when loading K/V.
    """
    # virtual_index is the kv sequence index in the current batch element
    # page_table_ptr is already pointed at current batch element's block table entry
    # block_size is the size of each block in the page table
    virtual_page_index = virtual_index // block_size
    page_offset = virtual_index % block_size
    if boundary_check:
        page_block_index = tl.load(
            page_table_ptr + virtual_page_index,
            mask=virtual_index < max_virtual_index,
            other=0,
        ).to(tl.int32)
    else:
        page_block_index = tl.load(page_table_ptr + virtual_page_index).to(tl.int32)
    return page_block_index * block_size + page_offset
@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    boundary_check: tl.constexpr = False,
):
    """Load K (transposed, BLOCK_K x N) and V (N x BLOCK_K) tiles from
    the paged KV cache for the given logical positions.

    Positions are translated through the page table first. When
    d < BLOCK_K the padding lanes are masked to 0.
    NOTE(review): unlike virtual_to_cache, out-of-range positions are
    not masked here (boundary_check only affects the page-table load) —
    presumably the caller masks them downstream; verify.
    """
    kvcache_idx = virtual_to_cache(
        virtual_index, max_virtual_index, page_table_ptr, block_size, boundary_check
    )
    # K laid out head-dim-major (for Q @ K^T); V head-dim-minor.
    k_offset = tl.arange(0, BLOCK_K)[:, None] + kvcache_idx[None, :] * k_row_stride
    v_offset = tl.arange(0, BLOCK_K)[None, :] + kvcache_idx[:, None] * k_row_stride
    if d == BLOCK_K:
        bK = tl.load(k_ptr_base + k_offset)
        bV = tl.load(v_ptr_base + v_offset)
    else:
        bK = tl.load(
            k_ptr_base + k_offset, mask=tl.arange(0, BLOCK_K)[:, None] < d, other=0.0
        )
        bV = tl.load(
            v_ptr_base + v_offset, mask=tl.arange(0, BLOCK_K)[None, :] < d, other=0.0
        )
    return bK, bV
@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Flash-attention forward for variable-length batches with a paged
    # kv-cache. Each program computes one BLOCK_M x d output tile for one
    # (m_block, batch, head) triple and writes the per-row log-sum-exp into
    # softmax_lse (laid out [h, total_q]).
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)
    # num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        # k_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        # k_offset = bid * k_batch_stride

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: this m-block has no valid query rows. Use >= so that a block
    # starting exactly at q_len (q_len a multiple of BLOCK_M) also exits
    # instead of doing fully-masked work.
    if m_block * BLOCK_M >= q_len:
        return

    # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    # (GQA/MQA: h_hk_ratio query heads share one kv head)
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    # Online-softmax accumulators: running output, row max, row sum.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Number of trailing n-blocks that need explicit masking (causal/local
    # diagonal and the ragged k_len edge); the remaining blocks are full.
    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    n_block = n_block_max - 1
    # Masked steps: walk backwards from the last (partial/diagonal) n-block.
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            k_len,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
            boundary_check=True,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # FIX: (row_start, col_start) must be (query offset, key offset),
            # matching apply_dropout's signature and the non-masking loop
            # below; previously these two were swapped here, so the masked
            # and unmasked loops sampled inconsistent dropout patterns.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Unmasked steps: all remaining n-blocks are fully inside the valid
    # (and, for causal, fully visible) region, so no mask is applied.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            k_len,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # FIX: parenthesize (rowsum_ == 0) explicitly — `|` binds tighter than
    # `==` in Python, so the unparenthesized form compared rowsum_ against
    # the int-cast NaN mask and failed to flag empty/NaN rows.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    # Final softmax normalization of the accumulated output.
    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )