Coverage for src/flag_gems/fused/deepseek_v4_attention_combine_topk_swa_indices.py: 48%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1from typing import Tuple
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9_SPARSE_PREFILL_TOPK_ALIGNMENT = 128
12def _next_power_of_2_or_1(x: int) -> int:
13 return 1 if x <= 1 else triton.next_power_of_2(x)
@triton.jit
def _combine_topk_swa_indices_kernel(
    combined_ptr,
    combined_stride,
    lens_ptr,
    topk_ptr,
    topk_stride,
    query_start_loc_ptr,
    seq_lens_ptr,
    gather_lens_ptr,
    M,
    N,
    TOP_K: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    PADDED_TOP_K: tl.constexpr,
    PADDED_WINDOW_SIZE: tl.constexpr,
):
    """Fill one output row per query token: top-k entries, then window entries.

    Grid axis 0 selects the request; the programs along axis 1 split that
    request's token rows between them (program ``w`` handles rows
    ``query_start + w``, ``query_start + w + num_programs(1)``, ...).

    For a token at position ``pos`` the row receives
      * ``topk_len = min((pos + 1) // COMPRESS_RATIO, TOP_K)`` values copied
        from the matching ``topk_ptr`` row, each shifted by ``M * request``;
      * ``swa_len = min(pos + 1, WINDOW_SIZE)`` consecutive values starting at
        ``M * request + N + pos - swa_len + 1 - gather_start``, written
        directly after the top-k part.
    ``lens_ptr[row]`` is set to ``topk_len + swa_len``. Elements of the row
    beyond that are not written by this kernel.

    Row elements of both ``combined_ptr`` and ``topk_ptr`` are addressed with
    unit stride; only the row strides are passed in.
    NOTE(review): ``M`` looks like a per-request stride and ``N`` like the
    offset of the window region inside it - confirm with the caller.
    """
    req = tl.program_id(0)
    first_worker = tl.program_id(1)
    worker_stride = tl.num_programs(1)

    # Token rows are counted relative to the first entry of query_start_loc.
    loc_base = tl.load(query_start_loc_ptr)
    q_begin = tl.load(query_start_loc_ptr + req) - loc_base
    q_end = tl.load(query_start_loc_ptr + req + 1) - loc_base
    q_len = q_end - q_begin

    total_len = tl.load(seq_lens_ptr + req)
    gathered = tl.load(gather_lens_ptr + req)
    first_pos = total_len - q_len  # position of this request's first query token
    gather_start = total_len - gathered

    # Loop-invariant values, bound once.
    req_offset = M * req
    topk_cols = tl.arange(0, PADDED_TOP_K)
    swa_cols = tl.arange(0, PADDED_WINDOW_SIZE)

    for row in range(q_begin + first_worker, q_end, worker_stride):
        pos = first_pos + (row - q_begin)
        n_topk = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K)
        n_swa = tl.minimum(pos + 1, WINDOW_SIZE)
        out_row = combined_ptr + row * combined_stride

        # Top-k part: copy the first n_topk entries, shifted by the request offset.
        topk_keep = topk_cols < n_topk
        picked = tl.load(
            topk_ptr + row * topk_stride + topk_cols, mask=topk_keep, other=-1
        )
        tl.store(out_row + topk_cols, picked + req_offset, mask=topk_keep)

        # Window part: n_swa consecutive indices placed right after the top-k part.
        swa_keep = (swa_cols < n_swa) & (swa_cols < WINDOW_SIZE)
        tl.store(
            out_row + n_topk + swa_cols,
            req_offset + N + swa_cols + pos - n_swa + 1 - gather_start,
            mask=swa_keep,
        )

        tl.store(lens_ptr + row, n_topk + n_swa)
def combine_topk_swa_indices(
    topk_indices: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    gather_lens: torch.Tensor,
    window_size: int,
    compress_ratio: int,
    topk: int,
    M: int,
    N: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build per-token index rows holding top-k entries followed by window entries.

    Launches ``_combine_topk_swa_indices_kernel`` on a
    ``(num_requests, 128)`` grid; see the kernel for the exact per-row layout.

    Args:
        topk_indices: 2-D tensor with one row per query token. Rows may have
            any row stride; a non-unit element stride is handled by copying.
        query_start_loc: per-request token offsets; entry ``i`` and ``i + 1``
            delimit request ``i``. Offsets are taken relative to entry 0.
        seq_lens: one length per request; its first dimension defines the
            number of requests.
        gather_lens: one value per request; ``seq_len - gather_len`` is
            subtracted from the window indices.
        window_size: upper bound on the number of window entries per token.
        compress_ratio: divisor applied to ``pos + 1`` to bound the number of
            top-k entries used for a token.
        topk: upper bound on the number of top-k entries per token.
        M: multiplied by the request index and added to every written index.
        N: additionally added to the window indices.
            NOTE(review): ``M``/``N`` presumably describe the layout of a
            per-request gathered buffer - confirm with the caller.

    Returns:
        ``(combined, lens)``: ``combined`` is int32 of shape
        ``(num_tokens, combined_topk)`` where ``combined_topk`` is
        ``topk + window_size`` rounded up to a multiple of
        ``_SPARSE_PREFILL_TOPK_ALIGNMENT``, with unused entries equal to -1;
        ``lens`` is int32 of shape ``(num_tokens,)`` with the number of
        entries written per row (0 for rows outside every request's range).
    """
    assert topk_indices.ndim == 2

    # The kernel receives only the row stride and reads row elements at unit
    # stride, so a view with a strided last dimension must be materialized.
    if topk_indices.stride(-1) != 1:
        topk_indices = topk_indices.contiguous()

    device = topk_indices.device
    num_tokens = topk_indices.shape[0]
    num_reqs = seq_lens.shape[0]

    # Round the row width up to the next multiple of the alignment.
    combined_topk = (
        (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1)
        // _SPARSE_PREFILL_TOPK_ALIGNMENT
        * _SPARSE_PREFILL_TOPK_ALIGNMENT
    )
    combined = torch.full(
        (num_tokens, combined_topk), -1, device=device, dtype=torch.int32
    )
    # Zero-initialized rather than empty: the kernel only writes rows that
    # fall inside some request's query range, so any other row would
    # otherwise expose uninitialized memory next to an all -1 combined row.
    lens = torch.zeros((num_tokens,), device=device, dtype=torch.int32)

    # Programs along the second grid axis share one request's token rows.
    workers_per_request = 128
    with torch_device_fn.device(device):
        _combine_topk_swa_indices_kernel[(num_reqs, workers_per_request)](
            combined,
            combined.stride(0),
            lens,
            topk_indices,
            topk_indices.stride(0),
            query_start_loc,
            seq_lens,
            gather_lens,
            M,
            N,
            TOP_K=topk,
            COMPRESS_RATIO=compress_ratio,
            WINDOW_SIZE=window_size,
            PADDED_TOP_K=_next_power_of_2_or_1(topk_indices.shape[-1]),
            PADDED_WINDOW_SIZE=_next_power_of_2_or_1(window_size),
        )
    return combined, lens


__all__ = ["combine_topk_swa_indices"]