Coverage for src/flag_gems/ops/flash_api.py: 91%
373 statements
coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

import logging
import math

import torch
import triton

import flag_gems
from flag_gems import runtime
from flag_gems.ops.flash_kernel import (
    block_m_splitkv_heuristic,
    block_n_splitkv_heuristic,
    flash_fwd_kernel,
    flash_fwd_splitkv_combine_kernel,
    flash_fwd_splitkv_kernel,
    flash_varlen_fwd_kernel,
)
from flag_gems.runtime import torch_device_fn
from flag_gems.utils.random_utils import philox_backend_seed_offset

logger = logging.getLogger(__name__)
_debug = False


def CHECK_DEVICE(x):
    assert x.device.type == flag_gems.device


class fwd_params:
    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        self.q_ptr = q_ptr
        self.k_ptr = k_ptr
        self.v_ptr = v_ptr
        self.o_ptr = o_ptr
        self.p_ptr = p_ptr
        self.softmax_lse_ptr = softmax_lse_ptr
        self.q_row_stride = q_row_stride
        self.k_row_stride = k_row_stride
        self.v_row_stride = v_row_stride
        self.q_head_stride = q_head_stride
        self.k_head_stride = k_head_stride
        self.v_head_stride = v_head_stride
        self.o_row_stride = o_row_stride
        self.o_head_stride = o_head_stride
        self.q_batch_stride = q_batch_stride
        self.k_batch_stride = k_batch_stride
        self.v_batch_stride = v_batch_stride
        self.o_batch_stride = o_batch_stride
        self.is_cu_seqlens_q = is_cu_seqlens_q
        self.cu_seqlens_q_ptr = cu_seqlens_q_ptr
        self.is_cu_seqlens_k = is_cu_seqlens_k
        self.cu_seqlens_k_ptr = cu_seqlens_k_ptr
        self.is_seqused_k = is_seqused_k
        self.seqused_k_ptr = seqused_k_ptr
        # sizes
        self.b = b
        self.bk = bk
        self.h = h
        self.hk = hk
        self.h_hk_ratio = h_hk_ratio
        self.seqlen_q = seqlen_q
        self.seqlen_k = seqlen_k
        self.seqlen_q_rounded = seqlen_q_rounded
        self.seqlen_k_rounded = seqlen_k_rounded
        self.d = d
        self.d_rounded = d_rounded
        # scaling factors
        self.is_softcap = is_softcap
        self.softcap = softcap
        self.scale_softmax = scale_softmax
        self.scale_softmax_log2 = scale_softmax_log2
        # dropout
        self.is_dropout = is_dropout
        self.p_dropout = p_dropout
        self.rp_dropout = rp_dropout
        self.p_dropout_in_uint8_t = p_dropout_in_uint8_t
        self.philox_args = philox_args
        self.return_softmax = return_softmax
        # masking
        self.is_causal = is_causal
        self.is_local = is_local
        self.window_size_left = window_size_left
        self.window_size_right = window_size_right
        self.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped
        # alibi
        self.is_alibi = is_alibi
        self.alibi_slopes_ptr = alibi_slopes_ptr
        self.alibi_slopes_batch_stride = alibi_slopes_batch_stride
        # block table
        self.total_q = total_q
        self.page_table_ptr = page_table_ptr
        self.page_table_batch_stride = page_table_batch_stride
        self.block_size = block_size

    def args(self):
        return tuple(getattr(self, k) for k in self.__slots__)
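

# Illustrative note (not part of the file's API): fwd_params is a positional
# record whose slot order must match the Triton kernel signatures, since
# launches below unpack it directly:
#
#     params = fwd_params(q, k, v, ...)   # one value per slot, in order
#     flash_fwd_kernel[grid](*params.args())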


def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only supports fp16 and bf16 data types"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape (paged KV): [num_pages, block_size, num_heads_k, head_size]
    # batch_size is the number of sequences in the batch
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # Disable sliding-window attention when the window already covers the
    # whole key sequence.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0
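    # Note: a window size of -1 means "unbounded" on that side, so is_local is
    # only set when a finite left window remains in effect.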

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
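    # Decode-phase trick: the q_groups query heads that share one KV head are
    # reinterpreted as a fake sequence dimension, so each CTA covers several
    # rows at once. Illustrative shapes: with num_heads=32 and num_heads_k=8,
    # q of (B, 32, D) becomes (B * 4, 8, D) with max_seqlen_q = 4.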
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)
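    # Worked example: head_size 80 -> 96 and 200 -> 256 (> 192 clamps to 256);
    # max_seqlen_q 100 -> 128; max_seqlen_k 1000 -> 1024.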

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E
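
    # With softcap, the scores are expected to follow the usual
    # flash-attention capping scheme, softcap * tanh(qk * scale / softcap):
    # adjusted_softcap pre-scales qk, adjusted_scale_softmax rescales the tanh
    # output, and the *_log2e variant folds in log2(e) for the kernel's
    # exp2-based softmax.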

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params for the kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout
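        # p_dropout now holds the keep probability; p_dropout_in_uint8_t is
        # that probability quantized to uint8 for cheap comparison against
        # Philox-generated random bytes in the kernel, and rp_dropout is the
        # 1/keep rescaling factor applied to surviving values.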

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr
            k,  # k_ptr
            v,  # v_ptr
            out,  # o_ptr
            p,  # p_ptr
            lse,  # softmax_lse_ptr
            q.stride(-3),  # q_row_stride
            k.stride(-3),  # k_row_stride
            v.stride(-3),  # v_row_stride
            q.stride(-2),  # q_head_stride
            k.stride(-2),  # k_head_stride
            v.stride(-2),  # v_head_stride
            out.stride(-3),  # o_row_stride
            out.stride(-2),  # o_head_stride
            q_batch_stride,
            k_batch_stride,
            v_batch_stride,
            o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q
            cu_seqlens_q,  # cu_seqlens_q_ptr
            seqused_k is None,  # is_cu_seqlens_k
            cu_seqlens_k,  # cu_seqlens_k_ptr
            seqused_k is not None,  # is_seqused_k
            seqused_k,  # seqused_k_ptr
            # sizes
            batch_size,  # b
            k_batch_size,  # bk
            num_heads,  # h
            num_heads_k,  # hk
            num_heads // num_heads_k,  # h_hk_ratio
            max_seqlen_q,  # seqlen_q
            max_seqlen_k,  # seqlen_k
            seqlen_q_rounded,
            seqlen_k_rounded,
            head_size,  # d
            head_size_rounded,  # d_rounded
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap
            adjusted_scale_softmax,  # scale_softmax
            adjusted_scale_softmax_log2e,  # scale_softmax_log2
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,
            is_local,
            window_size_left,
            window_size_right,
            seqlenq_ngroups_swapped,
            # alibi
            is_alibi,
            alibi_slopes,  # alibi_slopes_ptr
            alibi_slopes_batch_stride,
            # block table params
            total_q,
            page_table,  # page_table_ptr
            page_table_batch_stride,
            block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = tuple(getattr(params, k) for k in params.__slots__)

        # We assess which phase the requests are likely to be in and pick the
        # config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: a large avg_rows_per_cta (> 64) suggests the prefill
        # phase, so pick a larger block; small values suggest decode. This is
        # rough and may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"
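        # Illustrative numbers: 4096 q-tokens x 32 heads over 4 sequences on a
        # 128-SM device gives avg_rows_per_cta = min(1024, 1024) = 1024, so
        # "mha_block_128" is picked; a 64-sequence decode step with one token
        # per sequence gives avg_rows_per_cta = 1 and lands in "mha_block_16".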

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p


def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only supports fp16 and bf16 data types"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert head_size == head_size_rounded, "head_size must be a multiple of 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        # splits when wave efficiency is low
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits
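
    # Illustrative: num_tasks=16 on 108 SMs with 64 K-blocks runs one wave at
    # eff ~= 0.148, so min(cdiv(64, 2), floor(1 / 0.148), 108) = 6 splits.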

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            num_sms = torch_device_fn.get_device_properties(
                flag_gems.device
            ).multi_processor_count

            # Try bh parallel
            # if B * H > 0.8 * num_sms:
            #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
            #     # Yield kernel and prefilled args
            #     return kernel, default_args, None, None

            # Try splitkv
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(seqlen_k, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)
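
                # Split-KV path: each split scans a contiguous chunk of K/V
                # blocks and writes partial outputs plus per-split LSE values;
                # the combine kernel below merges the splits with a
                # log-sum-exp reduction.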
                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr
            k,  # k_ptr
            v,  # v_ptr
            out,  # o_ptr
            p,  # p_ptr
            lse,  # softmax_lse_ptr
            q.stride(-3),  # q_row_stride
            k.stride(-3),  # k_row_stride
            v.stride(-3),  # v_row_stride
            q.stride(-2),  # q_head_stride
            k.stride(-2),  # k_head_stride
            v.stride(-2),  # v_head_stride
            out.stride(-3),  # o_row_stride
            out.stride(-2),  # o_head_stride
            q.stride(0),  # q_batch_stride
            k.stride(0),  # k_batch_stride
            v.stride(0),  # v_batch_stride
            out.stride(0),  # o_batch_stride
            False,  # is_cu_seqlens_q
            None,  # cu_seqlens_q_ptr
            False,  # is_cu_seqlens_k
            None,  # cu_seqlens_k_ptr
            False,  # is_seqused_k
            None,  # seqused_k_ptr
            # sizes
            batch_size,  # b
            0,  # bk
            num_heads,  # h
            num_heads_k,  # hk
            num_heads // num_heads_k,  # h_hk_ratio
            seqlen_q,
            seqlen_k,
            seqlen_q_rounded,
            seqlen_k_rounded,
            head_size,  # d
            head_size_rounded,  # d_rounded
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap
            adjusted_scale_softmax,  # scale_softmax
            adjusted_scale_softmax_log2e,  # scale_softmax_log2
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,
            is_local,
            window_size_left,
            window_size_right,
            seqlenq_ngroups_swapped,
            # alibi
            is_alibi,
            alibi_slopes,  # alibi_slopes_ptr
            alibi_slopes_batch_stride,
            # block table params
            0,  # total_q
            None,  # page_table_ptr
            0,  # page_table_batch_stride
            0,  # block_size
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
            # print(kernel.asm['ttgir'])

        if seqlenq_ngroups_swapped:
            out = out.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            q = q.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p
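

# Minimal usage sketch (illustrative only; shapes, dtype, and the softmax
# scale below are assumptions, and a flag_gems-supported device is required):
#
#     q = torch.randn(2, 1024, 32, 128, device="cuda", dtype=torch.float16)
#     k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
#     v = torch.randn_like(k)
#     out, _, _, _, lse, *_ = mha_fwd(
#         q, k, v, None, None, 0.0, 128**-0.5, True, -1, -1, 0.0, False
#     )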