Coverage for src/flag_gems/ops/flash_api.py: 91%
377 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import math
4import torch
5import triton
7import flag_gems
8from flag_gems import runtime
9from flag_gems.ops.flash_kernel import (
10 block_m_splitkv_heuristic,
11 block_n_splitkv_heuristic,
12 flash_fwd_kernel,
13 flash_fwd_splitkv_combine_kernel,
14 flash_fwd_splitkv_kernel,
15 flash_varlen_fwd_kernel,
16)
17from flag_gems.runtime import torch_device_fn
18from flag_gems.utils.random_utils import philox_backend_seed_offset
# Module-level logger (lazy %-style args are used at call sites).
logger = logging.getLogger(__name__)
# Debug switch: when True, mha_fwd allocates a full softmax buffer, forces
# return_softmax, and prints launched-kernel metadata (shared mem, warps, stages).
_debug = False
def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the active flag_gems device.

    Raises:
        AssertionError: if the tensor's device type does not match
            ``flag_gems.device``. Note this check is stripped under
            ``python -O``.
    """
    # Include both devices in the message so a mismatch is diagnosable
    # (the original bare assert reported nothing).
    assert (
        x.device.type == flag_gems.device
    ), f"expected tensor on device '{flag_gems.device}', got '{x.device.type}'"
class fwd_params:
    """Flat argument bundle for the flash-attention forward Triton kernels.

    Field order in ``__slots__`` is the exact positional order the kernels
    expect; ``args()`` serializes the instance back into that order.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        # pointers and strides
        q_ptr, k_ptr, v_ptr, o_ptr, p_ptr, softmax_lse_ptr,
        q_row_stride, k_row_stride, v_row_stride,
        q_head_stride, k_head_stride, v_head_stride,
        o_row_stride, o_head_stride,
        q_batch_stride, k_batch_stride, v_batch_stride, o_batch_stride,
        is_cu_seqlens_q, cu_seqlens_q_ptr,
        is_cu_seqlens_k, cu_seqlens_k_ptr,
        is_seqused_k, seqused_k_ptr,
        # sizes
        b, bk, h, hk, h_hk_ratio,
        seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
        d, d_rounded,
        # scaling factors
        is_softcap, softcap, scale_softmax, scale_softmax_log2,
        # dropout
        is_dropout, p_dropout, rp_dropout, p_dropout_in_uint8_t,
        philox_args, return_softmax,
        # masking
        is_causal, is_local, window_size_left, window_size_right,
        seqlenq_ngroups_swapped, is_paged,
        # alibi
        is_alibi, alibi_slopes_ptr, alibi_slopes_batch_stride,
        # block table
        total_q, page_table_ptr, page_table_batch_stride, block_size,
    ):
        # Parameter names and order mirror __slots__ exactly, so every
        # argument can be bound to its slot by name in one pass instead of
        # 58 hand-written assignments.
        bound = locals()
        for slot in self.__slots__:
            setattr(self, slot, bound[slot])

    def args(self):
        """Return all field values as a tuple, in ``__slots__`` order."""
        return tuple(getattr(self, name) for name in self.__slots__)
def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length (ragged-batch) FlashAttention forward pass.

    All sequences are packed along one token dimension and delimited by the
    int32 prefix-sum tensors ``cu_seqlens_q`` / ``cu_seqlens_k`` (length
    ``batch_size + 1``). Supports an optional paged KV cache (``page_table``),
    ALiBi slopes, dropout, logit soft-capping, causal and sliding-window
    masking.

    Args:
        q: Queries, ``[total_q_tokens, num_heads, head_size]``, fp16/bf16.
        k, v: Keys/values. Paged layout:
            ``[num_pages, block_size, num_heads_k, head_size]``; otherwise
            the head dimension is at index 1.
        out: Optional pre-allocated output, same shape/dtype as ``q``.
        cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths.
        seqused_k: Optional per-sequence count of valid K tokens,
            ``[batch_size]``; when given, it is passed to the kernel instead
            of relying on ``cu_seqlens_k`` (see the is_cu_seqlens_k flag).
        leftpad_k: Unsupported; must be None.
        page_table: Optional int32 page-index table; None disables paging.
        alibi_slopes: Optional float tensor ``[num_heads]`` or
            ``[batch_size, num_heads]``.
        max_seqlen_q, max_seqlen_k: Maximum per-sequence lengths.
        p_dropout: Dropout probability in [0, 1); 0 disables dropout.
        softmax_scale: Softmax scaling factor.
        zero_tensors: When True, pre-zero ``out`` and fill lse with -inf.
        is_causal, window_size_left, window_size_right: Masking controls;
            a negative window size means unbounded on that side.
        softcap: Soft-capping value; > 0 enables it (incompatible with
            dropout).
        return_softmax: Also materialize the softmax matrix (debug only;
            requires non-zero dropout).
        gen: RNG generator placeholder; not used by this implementation.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    # Device/dtype/layout validation.
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    # A non-None page_table selects the paged-KV layout; otherwise an empty
    # placeholder keeps the kernel argument list uniform.
    is_paged = page_table is not None
    if not is_paged:
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Causal masking is a no-op for single-token queries (when not using
    # alibi), so drop it to enable the decode fast path below.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        # MQA/GQA decode fast path: treat the q_groups query heads per kv
        # head as a pseudo sequence dimension of length q_groups.
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        # After the swap every sequence has fixed length, so batch strides
        # replace the cu_seqlens_q indirection.
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    # Refresh total_q: the swap above may have changed q's leading dim.
    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # Pre-fold scale/softcap: with softcap active the per-element scale
    # becomes softmax_scale/softcap and the cap itself becomes the outer
    # multiplier; log2(e) folds the exp->exp2 trick into the scale.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        # When the swap fast path is active, the caller's `out` has the
        # wrong layout for the kernel: compute into a scratch tensor and
        # copy back after launch.
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on, p_dropout is the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): cu_seqlens_k is used only when seqused_k is not
            # provided — presumably seqused_k takes precedence in the kernel;
            # confirm against flash_varlen_fwd_kernel.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        # Vendor workaround: flatten the (heads, head_size) trailing dims of
        # k/v for iluvatar hardware.
        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = tuple(getattr(params, k) for k in params.__slots__)

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: if avg_rows_per_sm >= 128, we are likely in prefill phase.
        # This is a rough heuristic and may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        # Undo the group/sequence swap so callers see the original layout.
        if seqlenq_ngroups_swapped:
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # NOTE(review): the back-to-back reshapes reinterpret (hk, b, sq)
            # memory as (hk*sq, b) without a permute — confirm this matches
            # the layout the caller expects.
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    # Placeholder to keep the flash-attn-style 8-tuple return signature.
    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p
def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Fixed-length (dense-batch) FlashAttention forward pass.

    Dispatches between the split-KV kernel pair (for low-occupancy decode
    workloads) and the plain forward kernel.

    Args:
        q: Queries, ``[batch, seqlen_q, num_heads, head_size]``, fp16/bf16.
        k, v: Keys/values, ``[batch, seqlen_k, num_heads_k, head_size]``.
        out: Optional pre-allocated output, same shape/dtype as ``q``.
        alibi_slopes: Optional float tensor ``[num_heads]`` or
            ``[batch, num_heads]``.
        p_dropout: Dropout probability in [0, 1); 0 disables dropout.
        softmax_scale: Softmax scaling factor.
        is_causal: Apply causal masking.
        window_size_left, window_size_right: Sliding-window sizes; negative
            means unbounded on that side.
        softcap: Soft-capping value; > 0 enables it.
        return_softmax: Also materialize the softmax matrix (debug only;
            requires non-zero dropout).
        disable_splitkv: Skip the split-KV dispatch path.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.
    """
    # Device/dtype/layout validation.
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    # Disable sliding-window attention when the window covers the whole seq.
    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    # Causal masking is a no-op for single-token queries (without alibi).
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # Re-derive the flags purely from window sizes after normalization.
    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    # MQA/GQA decode fast path: fold the query groups into a pseudo
    # sequence dimension.
    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert head_size == head_size_rounded, "head_size must be rounded to 32"

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        # splits when wave efficiency is low
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        # Cap splits so each split processes at least 2 KV blocks, and never
        # exceed what would saturate the device.
        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on, p_dropout is the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        # Pre-fold scale/softcap (see mha_varlan_fwd for the same trick).
        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching.  Closes over is_dropout/is_local/
        # disable_splitkv/seqlen_q/seqlen_k/out/lse/head_size from the
        # enclosing scope; returns the launched kernel handle (used by the
        # _debug metadata dump below).
        def dispatch(B, H, Q, K, D, params):
            # NOTE(review): hard-coded "cuda" here, while the varlen path
            # queries flag_gems.device — confirm for non-CUDA vendors.
            num_sms = torch_device_fn.get_device_properties(
                "cuda"
            ).multi_processor_count

            # Try bh parallel
            # if B * H > 0.8 * num_sms:
            #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
            #     # Yield kernel and prefilled args
            #     return kernel, default_args, None, None

            # Try splitkv: partition the KV sequence across extra CTAs, then
            # combine the partial outputs/lse in a second kernel.
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(seqlen_k, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    # Redirect kernel output to the per-split scratch buffers.
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    # Combine-kernel block size shrinks as head dim grows.
                    if D >= 128:
                        BLOCK_M = 4
                    elif D >= 64:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    BLOCK_K = triton.next_power_of_2(D)
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_split_stride": out_splits.stride(0),
                        "lse_split_stride": lse_splits.stride(0),
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): out_h_stride uses stride(-1), not
                        # stride(-2) — verify against the combine kernel's
                        # expected layout.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "BLOCK_K": BLOCK_K,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            # Force a full softmax buffer so the kernel writes it out.
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            False,  # is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        # Move TxD to last dims for correct stride in Triton tt.load
        if flag_gems.vendor_name == "iluvatar":
            params.q_ptr = q.transpose(1, 2)
            params.k_ptr = k.transpose(1, 2)
            params.v_ptr = v.transpose(1, 2)
        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
            # print(kernel.asm['ttgir'])

        # Undo the group/sequence swap so callers see the original layout.
        if seqlenq_ngroups_swapped:
            out = out.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            q = q.transpose(1, 2).reshape(
                (batch_size, 1, num_heads_k * seqlen_q, head_size)
            )
            lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    # Placeholder to keep the flash-attn-style 8-tuple return signature.
    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p