Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/flash_api.py: 0%
364 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
4import torch
5import triton
7import flag_gems
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import philox_backend_seed_offset
12from .flash_kernel import (
13 block_m_splitkv_heuristic,
14 block_n_splitkv_heuristic,
15 flash_fwd_kernel,
16 flash_fwd_splitkv_combine_kernel,
17 flash_fwd_splitkv_kernel,
18 flash_varlen_fwd_kernel,
19)
# Child of the package-wide "flag_gems" logger; lstrip(".") removes any
# leading dot from a relative __name__ so the child name stays valid.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Debug switch used by mha_fwd: when True it forces return_softmax and
# prints launched-kernel metadata (shared memory, num_warps, num_stages).
_debug = False
def CHECK_DEVICE(x):
    # Reject tensors that do not live on the active flag_gems backend device.
    device_type = x.device.type
    assert device_type == flag_gems.device
class fwd_params:
    """Flat container for every argument of the flash forward kernels.

    The field order in ``__slots__`` matches both the ``__init__`` parameter
    order and the positional argument order expected by the Triton kernels,
    so an instance can be splatted into a kernel call via :meth:`args`.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Every parameter name matches its slot name one-to-one, so a single
        # pass over locals() assigns all fields in declaration order.
        captured = locals()
        for field in self.__slots__:
            setattr(self, field, captured[field])

    def args(self):
        """Return all fields as a tuple in ``__slots__`` (kernel argument) order."""
        collected = [getattr(self, field) for field in self.__slots__]
        return tuple(collected)
def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length (paged-KV) FlashAttention forward pass.

    Validates and normalizes inputs, packs everything into a ``fwd_params``
    struct, and launches ``flash_varlen_fwd_kernel``.
    (NOTE(review): the name looks like a typo for "varlen"; kept for
    compatibility with existing callers.)

    Args:
        q: packed query tokens, shape ``[total_q, num_heads, head_size]``.
        k, v: paged KV cache, shape ``[num_pages, block_size, num_heads_k, head_size]``.
        out: optional preallocated output, same shape/dtype as ``q``.
        cu_seqlens_q, cu_seqlens_k: contiguous int32 cumulative lengths,
            shape ``[batch_size + 1]``.
        seqused_k: optional per-sequence used-KV counts, shape ``[batch_size]``;
            when provided, the kernel is told to use it instead of
            ``cu_seqlens_k`` (see the ``is_cu_seqlens_k`` flag below).
        leftpad_k: must be ``None`` (unsupported).
        page_table: required mapping from sequence to KV pages.
        alibi_slopes: optional float tensor ``[num_heads]`` or ``[batch_size, num_heads]``.
        max_seqlen_q, max_seqlen_k: maximum per-sequence lengths.
        p_dropout: dropout probability; converted to keep-probability below.
        softmax_scale: scale applied to QK^T.
        zero_tensors: if True, zero ``out`` and fill the lse with -inf first.
        is_causal, window_size_left, window_size_right: masking controls.
        softcap: > 0 enables soft-capping (incompatible with dropout).
        return_softmax: also materialize the probability matrix (dropout only).
        gen: unused; RNG state comes from ``philox_backend_seed_offset``.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    # paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Single-token queries without alibi never see future tokens, so causal
    # masking is a no-op and can be dropped.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        # Fold the query-group dimension into the sequence dimension:
        # [batch, 1*h, d] -> [batch*q_groups, h_k, d]; cu_seqlens_q becomes
        # meaningless afterwards and is cleared.
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        # Non-swapped path: batch strides are zeroed; addressing presumably
        # goes through cu_seqlens_* in the kernel — verify against
        # flash_varlen_fwd_kernel.
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    # Round sizes up to kernel-friendly multiples.
    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # With softcap the scale and cap are folded into each other so the kernel
    # applies tanh(softcap' * x) with a single pre-scale.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        # When the caller supplied `out` but we swapped dims, compute into a
        # scratch tensor and copy back into `out_` at the end.
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            # Placeholder only; never read when is_dropout is False.
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # cu_seqlens_k is only honored when no seqused_k was given.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        # Vendor workaround: iluvatar kernels expect KV with head/dim fused.
        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = tuple(getattr(params, k) for k in params.__slots__)

        # We assess which phase the requests are likely to be in and set the config accordingly.
        # prefill_config: BLOCK_M=128, BLOCK_N=32, num_warps=4, num_stages=3
        # decode_config: BLOCK_M=32, BLOCK_N=32, num_warps=4, num_stages=3
        avg_seqlen_q = total_q / batch_size
        if avg_seqlen_q >= 256:
            varlen_fwd_config_str = "mha_varlen_prefill"
        else:
            varlen_fwd_config_str = "mha_varlen_decode"
        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args),
        }

        logger.debug("Average query sequence length: %d", avg_seqlen_q)
        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        # Undo the seqlenq/ngroups swap on the outputs.
        if seqlenq_ngroups_swapped:
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # NOTE(review): two successive reshapes — the first result is
            # immediately reshaped again, which reorders elements relative to
            # a single reshape; looks intentional but worth confirming.
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p
def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Fixed-length FlashAttention forward pass.

    Validates inputs, builds a ``fwd_params`` struct, and dispatches to one
    of ``flash_fwd_splitkv_kernel`` (+ combine) or ``flash_fwd_kernel``.

    Args:
        q: queries, shape ``[batch, seqlen_q, num_heads, head_size]``.
        k, v: keys/values, shape ``[batch, seqlen_k, num_heads_k, head_size]``.
        out: optional preallocated output, same shape/dtype as ``q``.
        alibi_slopes: optional float tensor ``[num_heads]`` or ``[batch, num_heads]``.
        p_dropout: dropout probability; converted to keep-probability below.
        softmax_scale: scale applied to QK^T.
        is_causal, window_size_left, window_size_right: masking controls.
        softcap: > 0 enables soft-capping.
        return_softmax: also materialize the probability matrix (dropout only).
        disable_splitkv: force the non-split kernel path.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    # Windows at least as long as the sequence disable sliding-window masking.
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    # Single-token queries without alibi make causal masking a no-op.
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # Re-derive the flags from the normalized window sizes.
    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        # Fold query groups into the sequence dim: [b, 1, h, d] -> [b, g, h_k, d].
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    # Round sizes up to kernel-friendly multiples.
    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert head_size == head_size_rounded, "head_size must be rounded to 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        # splits when wave efficiency is low
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        # Bound splits by available KV blocks, the efficiency deficit,
        # and the SM count.
        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            # Placeholder only; never read when is_dropout is False.
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        # With softcap the scale and cap are folded into each other so the
        # kernel applies tanh(softcap' * x) with a single pre-scale.
        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            # Pick split-KV when occupancy would otherwise be low,
            # else fall back to the plain flash_fwd kernel.
            # NOTE(review): device queried via the literal "cuda" string even
            # on this backend — confirm torch_device_fn maps it correctly.
            num_sms = torch_device_fn.get_device_properties(
                "cuda"
            ).multi_processor_count

            # Try bh parallel
            # if B * H > 0.8 * num_sms:
            #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
            #     # Yield kernel and prefilled args
            #     return kernel, default_args, None, None

            # Try splitkv
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(seqlen_k, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    # Per-split partial results; combined below.
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    # Redirect the kernel's output to the split buffers.
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    # Combine-kernel tile height shrinks as head dim grows.
                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): taken from stride(-1) (the contiguous
                        # last dim), not stride(-2); confirm against the
                        # combine kernel's expectations.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
            # print(kernel.asm['ttgir'])

        # Undo the seqlenq/ngroups swap on the outputs.
        if seqlenq_ngroups_swapped:
            out = out.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            q = q.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p