Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/flash_kernel.py: 0%
534 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import triton
2import triton.language as tl
4# from flag_gems import runtime
5from flag_gems.utils import libentry, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split a uint64 value into two uint32 words, returned as (high, low)."""
    high = (x >> 32).to(tl.uint32)
    low = (x & 0xFFFFFFFF).to(tl.uint32)
    return high, low
@triton.jit
def u64_from_lohi(lo, hi):
    """Reassemble a uint64 from its low and high 32-bit words.

    BUGFIX: `<<` binds looser than `+` in Python, so the original
    `hi << 32 + lo` evaluated as `hi << (32 + lo)`, producing garbage
    whenever lo != 0. Parenthesize the shift before adding the low word.
    """
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    # Philox 4x32 counter-based RNG round function.
    # The 128-bit counter is assembled from (offset, subsequence) and the
    # 64-bit key from `seed`; returns four 32-bit random words.
    kPhilox10A: tl.constexpr = 0x9E3779B9  # key bump constants (golden-ratio style)
    kPhilox10B: tl.constexpr = 0xBB67AE85
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # pragma unroll
    kPhiloxSA: tl.constexpr = 0xD2511F53  # per-round multipliers
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    # Six key-bumped rounds in the loop plus one final round below
    # (7 rounds total as written).
    for _ in tl.static_range(6):
        # 32x32 -> 64-bit multiplies; split each product into hi/lo words.
        res0 = kPhiloxSA * c0.to(tl.uint64)
        res1 = kPhiloxSB * c2.to(tl.uint64)
        res0_x, res0_y = u64_to_lohi(res0)
        res1_x, res1_y = u64_to_lohi(res1)
        # Mix counter words with the key and permute.
        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round without a key bump.
    res0 = kPhiloxSA * c0.to(tl.uint64)
    res1 = kPhiloxSB * c2.to(tl.uint64)
    res0_x, res0_y = u64_to_lohi(res0)
    res1_x, res1_y = u64_to_lohi(res1)
    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Drop the entries of P selected by `mask`.

    When `encode_dropout_in_sign_bit` is set, dropped entries are negated
    (so the mask can be recovered later); otherwise they are zeroed.
    """
    if encode_dropout_in_sign_bit:
        return tl.where(mask, -P, P)
    else:
        # (P * 0) keeps the dtype of P instead of introducing a literal zero.
        return tl.where(mask, (P * 0).to(P.dtype), P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Apply Philox-based dropout to a (BLOCK_M, BLOCK_N) tile of probabilities.
    # Each philox_ call yields four 32-bit words, so one call covers four
    # columns: the column counter is down-scaled by 4 and the four words are
    # joined/reshaped back to BLOCK_N columns.
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        row = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Down scale col_idx by 4
        col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        # Unique counter per (row, 4-col group) so the pattern is reproducible.
        subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Broadcast offset to the tile shape without changing its value.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)
        # An element is dropped when its low random byte clears the threshold.
        mask = (r & 0xFF) >= p_dropout_uint8
        P = apply_dropout_mask(
            P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    # Add the ALiBi positional bias to the score tile S (no-op unless is_alibi).
    # `alibi_slope` is assumed to be pre-divided by the softmax scale by the
    # caller, since S is scaled later.
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softamx(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #     -4, -3, -2,  X,  X,
            #     -4, -3, -2, -1,  X,
            #     -4, -3, -2, -1,  0,
            # ]
            # bias_2 = [
            #     -2, -1,  0,  X,  X,
            #     -3, -2, -1,  0,  X,
            #     -4, -3, -2, -1,  0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += bias
        else:
            # Non-causal: symmetric bias proportional to the (aligned) distance
            # between query and key positions.
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += bias

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    # Set masked-out score entries to -inf: causal masking, sliding-window
    # (local) masking, and/or out-of-bounds column masking when the sequence
    # length does not divide evenly into blocks.
    need_mask = is_causal | is_local | (not is_even_mn)
    # need_mask: tl.constexpr = is_causal | is_local
    if need_mask:
        # Extra care should be taken to void one-off errors: both col_lb and col_rb are inclusive!
        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        col_rb = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            # Causal: keys strictly to the right of the (aligned) diagonal are hidden.
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            # Local: keys outside the [col_lb, col_rb] window are hidden.
            S = tl.where(
                (col_idx[None, :] > col_rb[:, None])
                | (col_idx[None, :] < col_lb[:, None]),
                float("-inf"),
                S,
            )

        if (not is_local) & (not is_causal) & (not is_even_mn):
            # Only out-of-bounds masking is needed for the ragged tail block.
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    # Online-softmax update: fold the score tile S into the running per-row
    # (max, sum) statistics and rescale the output accumulator accordingly.
    # Returns (rescaled O_acc, unnormalized tile probabilities P,
    # updated row_max, updated row_sum).
    prev_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        # Fully-masked rows have row_max == -inf; substitute 0 so the exp2
        # below does not evaluate (-inf) - (-inf) = NaN.
        cur_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        cur_max = row_max

    # Rescale previous accumulator/sum by exp2((old_max - new_max) * scale).
    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
    row_sum *= p_scale
    O_acc *= p_scale[:, None]

    max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Optionally squash attention scores with tanh(S * softcap)."""
    if not is_softcap:
        return S
    return tl_extra_shim.tanh(S * softcap)
def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-kv kernel: wide blocks for small head dims."""
    if headdim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-kv kernel: shrink blocks for large head dims."""
    if headdim <= 64:
        return 64
    return 32
def is_even_mn(args):
    """Return True when all block and window boundaries divide evenly.

    In that case no ragged-tail masking is needed inside the kernel.
    Keys: M/N = seqlens, BM/BN = block sizes, WL/WR = window sizes (-1 = off).
    """
    m, n = args["M"], args["N"]
    bm, bn = args["BM"], args["BN"]
    wl, wr = args["WL"], args["WR"]
    blocks_align = m % bm == 0 and n % bn == 0
    seqs_align = blocks_align and (m % n == 0 or n % m == 0)
    windows_align = (wl == -1 or wl % bn == 0) and (wr == -1 or wr % bn == 0)
    return seqs_align and windows_align
def block_m_splitkv_heuristic_spec_args(args):
    """BLOCK_M heuristic taking the triton.heuristics args dict (key 'd')."""
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic_spec_args(args):
    """BLOCK_N heuristic taking the triton.heuristics args dict (key 'd')."""
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32
def is_even_mn_spec_args(args):
    """Long-key variant of is_even_mn for triton.heuristics.

    True when block sizes and (optional) window sizes all divide the sequence
    lengths evenly, so the kernel can skip ragged-tail masking.
    """
    sq, sk = args["seqlen_q"], args["seqlen_k"]
    bm, bn = args["BLOCK_M"], args["BLOCK_N"]
    wl, wr = args["window_size_left"], args["window_size_right"]
    if sq % bm != 0 or sk % bn != 0:
        return False
    if sq % sk != 0 and sk % sq != 0:
        return False
    wl_ok = wl == -1 or wl % bn == 0
    wr_ok = wr == -1 or wr % bn == 0
    return wl_ok and wr_ok
def keep(cfg, must_keep=None):
    """Autotune-config filter.

    Keep a config when its (BLOCK_M, BLOCK_N, num_warps) shape is one of the
    whitelisted shapes, or when it appears in `must_keep`.
    """
    shape = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    # we always keep configurations in `must_keep`
    return shape in ((128, 32, 4), (128, 128, 8)) or (must_keep and cfg in must_keep)
def prune_fwd_configs(configs, nargs, **kwargs):
    """Early-prune autotune configs for the forward kernel.

    With dropout enabled, only 4-warp configs with fewer than 4 stages are
    viable; otherwise the config list is returned unchanged.
    """
    if not nargs["is_dropout"]:
        return configs
    return [c for c in configs if c.num_warps == 4 and c.num_stages < 4]
280# @libentry()
281# @triton.autotune(
282# configs=list(filter(keep, runtime.get_tuned_config("attention"))),
283# prune_configs_by={"early_config_prune": prune_fwd_configs},
284# key=["d", "is_dropout"],
285# )
286# @triton.heuristics(
287# values={
288# "BLOCK_K": lambda args: triton.next_power_of_2(args["d"]),
289# "PRE_LOAD_V": lambda args: False,
290# "IS_EVEN_MN": is_even_mn,
291# }
292# )
293# @triton.jit(
294# do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
295# )
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel: one program per BLOCK_M query rows of
    one (batch, head) pair, iterating over key/value blocks with an
    online-softmax accumulator.

    BUGFIXES vs. the original:
      * `IS_EVEN_MN & d == BLOCK_K` parsed as `(IS_EVEN_MN & d) == BLOCK_K`
        (comparisons bind looser than `&`), which is almost never true, so
        the unmasked fast-path loads were effectively dead. Now parenthesized
        as `IS_EVEN_MN & (d == BLOCK_K)`.
      * `rowsum_ == 0 | (rowsum_ != rowsum_)` parsed as
        `rowsum_ == (0 | isnan)`, so NaN rows escaped the inf/1.0 remapping
        in the LSE epilogue. Now `(rowsum_ == 0) | (rowsum_ != rowsum_)`.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is
    # assigned to process. The frame edges are rounded to multiples of
    # BLOCK_M and BLOCK_N for rows and columns respectively.
    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that need explicit masking.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        # philox_args holds (seed, offset) as two consecutive u64 values.
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide by the softmax scale: scores are scaled later.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # ---- Load the Q tile for this row block ----
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # BUGFIX: was `IS_EVEN_MN & d == BLOCK_K` -> `(IS_EVEN_MN & d) == BLOCK_K`.
    if IS_EVEN_MN & (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        # p layout: (batch, head, seqlen_q_rounded, seqlen_k_rounded).
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    # Online-softmax state: output accumulator, running row max and row sum.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk  # GQA/MQA: queries per kv head
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with K's batch/row strides; this assumes K
    # and V share a layout (v_batch_stride / v_row_stride are unused) --
    # confirm against the caller.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K tile is loaded transposed (BLOCK_K, BLOCK_N); V tile is (BLOCK_N, BLOCK_K).
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    # ---- Phase 1: masked column blocks (processed right-to-left) ----
    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            # BUGFIX: parenthesized `(d == BLOCK_K)` (see docstring).
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Store the sign-encoded dropout pattern alongside P.
                    P_drop = P
                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                # BUGFIX: parenthesized `(d == BLOCK_K)` (see docstring).
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # ---- Phase 2: full (unmasked) column blocks ----
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(p_bk0 + off, mask=dmask[:, None], cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        # These blocks are fully inside bounds: only the local window (left
        # edge) can still mask anything here.
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # ---- Epilogue: LSE and normalized output ----
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum)
    # cancels the effect of rowmax and outputs lse only.
    # BUGFIX: parenthesized `(rowsum_ == 0)`; `|` binds tighter than `==`.
    lse = tl.where(
        (rowsum_ == 0) | (rowsum_ != rowsum_),
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    if is_dropout:
        # Fold the dropout correction 1/(1-p) into the normalization.
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    # BUGFIX: parenthesized `(d == BLOCK_K)` (see docstring).
    if IS_EVEN_MN & (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    # Placeholder for a batch*head-parallel forward kernel; not implemented yet.
    # (TODO)
    pass
737# @libentry()
738# @triton.heuristics(
739# values={
740# "BLOCK_M": block_m_splitkv_heuristic_spec_args,
741# "BLOCK_N": block_n_splitkv_heuristic_spec_args,
742# "BLOCK_K": lambda args: triton.next_power_of_2(args["d"]),
743# "num_warps": lambda args: 4,
744# "num_stages": lambda args: 3,
745# "PRE_LOAD_V": lambda args: True,
746# "IS_EVEN_MN": is_even_mn_spec_args,
747# }
748# )
749# @triton.jit(
750# do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
751# )
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-kv flash-attention forward: each program handles one split of
    `blocks_per_split` key blocks for one query row block of one (batch,
    head); partial outputs/LSEs are reduced later by the combine kernel.

    BUGFIXES vs. the original (operator precedence, see flash_fwd_kernel):
      * `IS_EVEN_MN & BLOCK_K == d` / `IS_EVEN_MN & d == BLOCK_K` parsed as
        `(IS_EVEN_MN & BLOCK_K) == d`; now `IS_EVEN_MN & (BLOCK_K == d)`.
      * `rowsum_ == 0 | (...)` parsed as `rowsum_ == (0 | isnan)`; now
        `(rowsum_ == 0) | (rowsum_ != rowsum_)`.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    # Range of key blocks owned by this split.
    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        # Causal rows never attend past the (aligned) diagonal.
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First key block index that requires masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    # ---- Load the Q tile ----
    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # BUGFIX: was `IS_EVEN_MN & BLOCK_K == d`.
    if IS_EVEN_MN & (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    h_hk_ratio = h // hk  # GQA/MQA: queries per kv head
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with K's strides; assumes K and V share a
    # layout (v_batch_stride / v_row_stride are unused) -- confirm upstream.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    # Online-softmax state for this split.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # This split contains blocks that need causal / boundary masking.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            # BUGFIX: parenthesized `(d == BLOCK_K)` (see docstring).
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                # BUGFIX: parenthesized `(d == BLOCK_K)` (see docstring).
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # ---- Epilogue: split-local LSE and rescaled partial output ----
    # Empty/NaN rows get lse = -inf so the combine kernel weighs them as zero.
    # BUGFIX: parenthesized `(rowsum_ == 0)`; `|` binds tighter than `==`.
    lse = tl.where(
        (rowsum_ == 0) | (rowsum_ != rowsum_),
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    # BUGFIX: was `IS_EVEN_MN & BLOCK_K == d`.
    if IS_EVEN_MN & (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    # Reduce the per-split partial outputs and LSEs produced by the split-kv
    # kernel into the final attention output and LSE. One program handles
    # BLOCK_M consecutive query rows.
    pid = tl.program_id(0)
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    # Mask off rows past q_total and splits past the actual split count;
    # masked lanes read -inf so they contribute nothing below.
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    # Per-split combination weights Zi/Z.
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    # Gather the (BLOCK_M, MAX_N_SPLITS, BLOCK_K) stack of partial outputs.
    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    # Final output = weighted sum of split outputs with weights Zi/Z.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])
@triton.jit
def virtual_to_cache(virtual_index, page_table_ptr, block_size):
    """Translate logical kv-sequence indices into paged-kvcache indices.

    `page_table_ptr` must already point at the current batch element's page
    table; `block_size` is the number of kv rows per page.
    """
    logical_page = virtual_index // block_size
    within_page = virtual_index % block_size
    physical_page = tl.load(page_table_ptr + logical_page).to(tl.int32)
    return physical_page * block_size + within_page
@triton.jit
def load_from_kvcache(
    i,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d,
    k_row_stride,
    BLOCK_K: tl.constexpr,
):
    # Gather a K tile (head-dim major, i.e. transposed) and a V tile
    # (row major) from the paged kv-cache for the virtual kv indices `i`.
    # NOTE(review): V is addressed with k_row_stride; assumes K and V caches
    # share a row stride -- confirm upstream.
    kvcache_idx = virtual_to_cache(i, page_table_ptr, block_size)
    k_offset = tl.arange(0, BLOCK_K)[:, None] + kvcache_idx[None, :] * k_row_stride
    v_offset = tl.arange(0, BLOCK_K)[None, :] + kvcache_idx[:, None] * k_row_stride
    # Mask off head-dim lanes beyond the true head size d.
    bK = tl.load(
        k_ptr_base + k_offset, mask=tl.arange(0, BLOCK_K)[:, None] < d, other=0.0
    )
    bV = tl.load(
        v_ptr_base + v_offset, mask=tl.arange(0, BLOCK_K)[None, :] < d, other=0.0
    )
    return bK, bV
@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Variable-length flash-attention forward with a paged KV cache.

    Grid: (m_blocks over query rows, batch, query heads). Each program
    computes one (BLOCK_M, d) output tile via online softmax over kv tiles
    fetched through the page table, then writes the output and the per-row
    log-sum-exp (lse, laid out [h, total_q]).

    Supports causal/local (sliding-window) masking, alibi, softcap and
    philox-based dropout; GQA/MQA via h_hk_ratio.
    """
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)

    # ---- Per-batch query extent & base offsets (varlen vs dense layout) ----
    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # ---- Per-batch kv extent; seqused_k (if given) overrides the cache length ----
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
    else:
        k_len_cache = seqlen_k

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA. `>=`: a block starting exactly at q_len owns zero valid rows
    # (the original `>` let such blocks run on through fully-masked work).
    if m_block * BLOCK_M >= q_len:
        return

    # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
    is_even_mn: tl.constexpr = False

    # ---- kv tile range visible to this query tile ----
    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    # (hid // h_hk_ratio maps query heads onto shared kv heads for GQA/MQA)
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    # Online-softmax state: running output, row max and row sum.
    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide so the slope survives the later softmax rescale.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Number of trailing kv tiles that need explicit causal/local masking.
    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    n_block = n_block_max - 1
    # ---- Masking steps: highest kv tiles, where causal/local masks apply ----
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # Fixed: row_start/col_start were swapped relative to the main
            # loop below; rows are query positions (BLOCK_M), cols are kv
            # positions (BLOCK_N), matching apply_dropout's signature.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # ---- Main loop: interior kv tiles need no causal mask ----
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        bK, bV = load_from_kvcache(
            col_idx,
            page_table_ptr,
            k_ptr_base,
            v_ptr_base,
            block_size,
            d,
            k_row_stride,
            BLOCK_K=BLOCK_K,
        )
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        # Interior tiles are fully in-bounds and fully unmasked for causal;
        # only the local window (if any) can still clip them.
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # ---- LSE / final normalization ----
    # Fixed precedence: `|` binds tighter than `==`, so the original
    # `rowsum_ == 0 | (...)` compared rowsum_ against the boolean mask
    # instead of testing for zero-or-NaN rows.
    degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )