Coverage for src/flag_gems/fused/top_k_per_row_prefill.py: 44%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
"""Triton top_k_per_row_prefill for DeepSeek V4 sparse attention.

Replaces vLLM's persistent_topk CUDA kernel with a Triton implementation.

Background:
    In DeepSeek V4 prefill, each token computes attention logits over a subset of
    the vocabulary [row_starts[i], row_ends[i]) and selects the top-K indices.
    Typical config: vocab_size=129280, top_k=1024, num_rows=1 (decode) or 32+ (prefill).

Strategy:
    1. In-place masking kernel: set logits outside [row_starts, row_ends) to -inf.
       Early exit when the row uses full vocab (start==0, end>=vocab_size), which is
       the common case during inference and avoids unnecessary memory writes.
    2. Adaptive top-K selection:
       - num_rows=1: torch.argsort (backed by CUB radix sort, O(N) for single row,
         ~2x faster than torch.topk for large vocab on a single row)
       - num_rows>1: torch.topk with sorted=False (heap-based O(N log k), better
         parallelism across rows than argsort)
    3. Fused postprocess kernel: single Triton kernel performs slice + cast + subtract
       in one pass, converting absolute vocab indices to 0-based indices relative to
       row_starts[i]. Saves one kernel launch vs separate slice/subtract ops.

Performance (DeepSeek V4 config, vocab=129280, top_k=1024):
    - num_rows=1: 0.89x vs vLLM CUDA (competitive, bounded by argsort)
    - num_rows=32: 0.38x vs vLLM CUDA (bounded by torch.topk on large vocab)
"""
28import torch
29import triton
30import triton.language as tl
@triton.jit
def _mask_invalid_kernel(
    logits_ptr,
    row_starts_ptr,
    row_ends_ptr,
    stride0,  # logits row stride, in elements
    BLOCK_SIZE: tl.constexpr,  # number of columns handled by one program
    VOCAB_SIZE: tl.constexpr,  # number of columns per logits row
):
    """Overwrite logits outside [row_starts[i], row_ends[i]) with -inf, in place.

    Launch layout: a flat 1D grid of ``num_rows * cdiv(VOCAB_SIZE, BLOCK_SIZE)``
    programs; each program owns one BLOCK_SIZE-wide chunk of one row.

    Column offsets are added to the row base pointer directly, so the kernel
    assumes the column dimension has unit stride.

    Args:
        logits_ptr: base pointer of the logits buffer (modified in place).
        row_starts_ptr: pointer to per-row inclusive window starts.
        row_ends_ptr: pointer to per-row exclusive window ends.
        stride0: distance between consecutive rows, in elements.
        BLOCK_SIZE: columns per program (compile-time constant).
        VOCAB_SIZE: columns per row (compile-time constant).
    """
    pid = tl.program_id(0)
    num_blocks_per_row = tl.cdiv(VOCAB_SIZE, BLOCK_SIZE)
    row_id = pid // num_blocks_per_row
    block_id = pid % num_blocks_per_row

    start = tl.load(row_starts_ptr + row_id)
    end = tl.load(row_ends_ptr + row_id)

    # A row whose window spans every column has nothing to mask; skip all
    # further work (and all stores) for it.
    if start == 0 and end >= VOCAB_SIZE:
        return

    # Columns covered by this program.
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    out_of_range = (offs < start) | (offs >= end)
    # Store only where the column exists AND lies outside the valid window.
    mask = (offs < VOCAB_SIZE) & out_of_range

    # Promote to 64-bit before multiplying: a 32-bit row_id * stride0 wraps
    # around once num_rows * stride0 reaches 2**31, which would send the
    # -inf stores to the wrong addresses for large batches.
    row_base = row_id.to(tl.int64) * stride0
    tl.store(logits_ptr + row_base + offs, float("-inf"), mask=mask)
@triton.jit
def _fused_postprocess_kernel(
    src_ptr,  # source indices (from argsort or topk)
    dst_ptr,  # destination: output indices buffer, row stride == top_k
    row_starts_ptr,  # per-row start offsets for index adjustment
    num_rows,  # runtime scalar: only used for the bounds guard below
    top_k: tl.constexpr,  # number of indices kept per row
    src_stride0: tl.constexpr,  # row stride of src, in elements
    BLOCK_SIZE: tl.constexpr,  # power of two >= top_k (tl.arange requirement)
):
    """Fused slice + subtract + cast: turn absolute indices into row-relative ones.

    For each row i: ``dst[i, :top_k] = int32(src[i, :top_k] - row_starts[i])``.

    Launch layout: one program per row, grid ``(num_rows,)``.

    ``num_rows`` is deliberately NOT a ``tl.constexpr``: compile-time constants
    are part of the kernel specialization, so a constexpr row count would force
    a new JIT compilation for every distinct batch size.

    Args:
        src_ptr: pointer to the selected absolute column indices.
        dst_ptr: pointer to the output buffer; rows are written ``top_k``
            elements apart.
        row_starts_ptr: pointer to per-row window starts.
        num_rows: number of rows to process.
        top_k: number of entries copied per row.
        src_stride0: distance between consecutive rows of ``src``, in elements.
        BLOCK_SIZE: lane count of the program; lanes >= top_k are masked off.
    """
    row_id = tl.program_id(0)
    if row_id >= num_rows:
        return

    row_start = tl.load(row_starts_ptr + row_id)

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < top_k

    # 64-bit row offsets so row_id * stride cannot wrap for large batches.
    row64 = row_id.to(tl.int64)

    src_vals = tl.load(src_ptr + row64 * src_stride0 + offs, mask=mask, other=0)

    # Shift so that index 0 corresponds to column row_start of this row.
    dst_vals = (src_vals - row_start).to(tl.int32)

    tl.store(dst_ptr + row64 * top_k + offs, dst_vals, mask=mask)
def top_k_per_row_prefill(
    logits: torch.Tensor,
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    indices: torch.Tensor,
    num_rows: int,
    stride0: int,
    stride1: int,
    top_k: int,
) -> None:
    """Select the top-K columns of each row inside its valid window.

    Columns outside ``[row_starts[i], row_ends[i])`` are overwritten with -inf
    in ``logits`` (in place), the ``top_k`` largest entries of each row are
    selected, and their column indices are written to ``indices`` relative to
    ``row_starts[i]``. Slots that could not be filled from inside the window
    (window shorter than ``top_k``) are set to -1.

    Args:
        logits: [num_rows, vocab_size] float32 tensor, modified in-place
            (masked to -inf outside each row's window).
        row_starts: [num_rows] int32, start of the valid range per row (inclusive).
        row_ends: [num_rows] int32, end of the valid range per row (exclusive).
        indices: [num_rows, top_k] int32 output buffer, pre-allocated by the
            caller; rows are written ``top_k`` elements apart.
        num_rows: number of rows to process.
        stride0: row stride of ``logits`` in elements (``logits.stride(0)``).
        stride1: column stride of ``logits`` in elements; must be 1 because the
            masking kernel addresses columns with unit stride.
        top_k: number of entries to select per row.

    Raises:
        ValueError: if ``top_k`` exceeds the number of columns, or if
            ``stride1`` is not 1 for a multi-column ``logits``.
    """
    vocab_size = logits.shape[1]

    if top_k > vocab_size:
        raise ValueError(f"top_k ({top_k}) must not exceed vocab_size ({vocab_size})")
    # The masking kernel adds column offsets straight to the row pointer, so a
    # non-unit column stride would silently mask the wrong elements.
    if vocab_size > 1 and stride1 != 1:
        raise ValueError(
            f"logits must have unit stride along dim 1, got stride1={stride1}"
        )
    if num_rows <= 0:
        # Nothing to select; avoid launching empty grids.
        return

    # --- Phase 1: mask everything outside each row's window to -inf ---
    MASK_BS = 8192  # columns per masking program
    num_mask_blocks = (vocab_size + MASK_BS - 1) // MASK_BS
    _mask_invalid_kernel[(num_rows * num_mask_blocks,)](
        logits,
        row_starts,
        row_ends,
        stride0,
        BLOCK_SIZE=MASK_BS,
        VOCAB_SIZE=vocab_size,
        num_warps=2,
    )

    # --- Phase 2: select top-K indices ---
    # tl.arange needs a power-of-two extent, so round top_k up.
    POSTPROC_BLOCK = triton.next_power_of_2(top_k)

    # Only the first num_rows rows take part; slicing is a free view and keeps
    # torch from sorting rows the caller did not ask for.
    active_logits = logits[:num_rows]

    if num_rows == 1:
        # Single row: full descending sort, then keep the first top_k entries.
        sorted_idx = torch.argsort(
            active_logits, dim=1, descending=True, stable=False
        )
        # argsort returns a full-width row, hence src_stride0=vocab_size.
        _fused_postprocess_kernel[(1,)](
            sorted_idx,
            indices,
            row_starts,
            num_rows=1,
            top_k=top_k,
            src_stride0=vocab_size,
            BLOCK_SIZE=POSTPROC_BLOCK,
            num_warps=4,
        )
    else:
        # Multiple rows: batched top-k; the order inside each row is not needed.
        _, top_idx = torch.topk(
            active_logits, top_k, dim=1, largest=True, sorted=False
        )
        # topk output has shape [num_rows, top_k], hence src_stride0=top_k.
        _fused_postprocess_kernel[(num_rows,)](
            top_idx,
            indices,
            row_starts,
            num_rows=num_rows,
            top_k=top_k,
            src_stride0=top_k,
            BLOCK_SIZE=POSTPROC_BLOCK,
            num_warps=4,
        )

    # --- Phase 3: invalidate slots that fell outside the window ---
    # When a window holds fewer than top_k columns, the selection above is
    # padded with -inf-masked columns; after the subtraction those show up as
    # values < 0 or >= window length. Replace them with -1 so they can never
    # be mistaken for real positions.
    # NOTE(review): -1 mirrors the padding convention of the vLLM CUDA kernel
    # this function replaces -- confirm downstream consumers skip -1 entries
    # (with sorted=False they are not guaranteed to sit at the end of a row).
    out = indices[:num_rows]
    row_len = (row_ends[:num_rows] - row_starts[:num_rows]).unsqueeze(1)
    out.masked_fill_((out < 0) | (out >= row_len), -1)