Coverage for src/flag_gems/fused/top_k_per_row_decode.py: 7%
305 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1"""Triton top_k_per_row_decode for DeepSeek V4 decode-phase token selection.
3Replaces vLLM's top_k_per_row_decode CUDA kernel with a pure Triton
4implementation using radix-select (4-iteration 8-bit histogram radix).
6Background:
7 In DeepSeek V4 decode, each step selects the top-K token indices from a
8 single row of logits [1, vocab_size]. The vLLM CUDA kernel uses a
9 radix-based approach; this Triton kernel matches that strategy with
10 three dispatch tiers optimized for different vocab sizes.
12Strategy:
13 1. Single-block path (vocab_size <= 8192): All data fits in one thread
14 block's registers. Four radix iterations with tl.histogram, no
15 inter-block synchronization, no global memory scratch.
16 2. Medium multi-block path (8192 < vocab_size <= 32768): All blocks
17 participate in all 4 radix iterations. Double-buffered per-block
18 histograms with 4 barriers (1 per iteration). Eliminates serial
19 block-0 bottleneck seen in buffer-based approaches.
20 3. Large multi-block path (vocab_size > 32768): First radix iteration
21 runs across all blocks with per-block histograms + barrier. Remaining
22 3 iterations run on block-0 only using a compacted buffer, avoiding
23 barrier overhead for high block counts.
25Performance (DeepSeek V4 config, H20 GPU):
26 - vocab=129280, k=1024: 1.82x faster than vLLM CUDA
27 - vocab=32768, k=512: 0.78x vs vLLM CUDA
28 - vocab=8192, k=128: 0.50x vs vLLM CUDA
29"""
31import logging
33import torch
34import triton
35import triton.language as tl
logger = logging.getLogger(__name__)

# int32 constant with only the sign bit set (bit pattern 0x80000000). XOR-ing a
# 32-bit key with it converts between unsigned and signed comparison order.
_SIGN_BIT = tl.constexpr(-(2**31))
@triton.jit
def _float_to_sortable(val):
    """Map float bit patterns to int32 keys that sort like the floats.

    Negative inputs get every bit inverted; non-negative inputs only get the
    sign bit flipped. Interpreted as *unsigned* 32-bit integers, the returned
    keys are ordered exactly like the original floats (XOR with ``_SIGN_BIT``
    gives the equivalent signed-int32 ordering).

    ``val`` must be a 32-bit float tensor: it is bitcast to ``tl.int32``.
    """
    raw = val.to(tl.int32, bitcast=True)
    # Arithmetic right shift: all ones for negative inputs, zero otherwise.
    negative_fill = raw >> 31
    flip = negative_fill | tl.full(raw.shape, _SIGN_BIT, dtype=tl.int32)
    return raw ^ flip
@triton.jit
def _topk_single_block(
    logits_ptr,
    seq_len_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
):
    """Single-block radix select: all 4 iterations in-register, no barriers.

    Writes (unordered) the indices of the TOP_K largest logits among
    positions ``[0, min(N, seq_len))`` to ``indices_ptr[0:TOP_K]``.

    Each radix iteration histograms one byte of the sortable key (MSB first)
    over the elements that still match all previously chosen pivot bytes and
    picks the largest bucket whose suffix count covers the remaining K. The
    four pivot bytes form ``threshold``, the key of the K-th largest element.
    Elements strictly above it are written first, then ties, capped so that
    exactly TOP_K slots are filled.

    If fewer than TOP_K positions are valid, every valid index is written and
    the remaining output slots are filled with -1.

    Args:
        logits_ptr: logits row (32-bit float; bitcast in ``_float_to_sortable``).
        seq_len_ptr: scalar; positions >= seq_len are ignored.
        indices_ptr: int32 output with TOP_K slots.
        stride1: element stride between consecutive logits.
        N: row length; only ``tl.arange(0, BLOCK)`` is loaded, so N <= BLOCK.
        BLOCK: power-of-two tile covering the row.
        TOP_K: number of indices to select.
    """
    offs = tl.arange(0, BLOCK)
    seq_len = tl.load(seq_len_ptr)
    valid = (offs < N) & (offs < seq_len)

    vals = tl.load(logits_ptr + offs * stride1, mask=valid, other=float("-inf"))
    sortable = _float_to_sortable(vals)

    bins = tl.arange(0, 256)

    # Radix iteration 0: byte 3 (MSB)
    bucket_0 = (sortable >> 24) & 0xFF
    counts_0 = tl.histogram(bucket_0, 256, mask=valid)
    total_0 = tl.sum(counts_0)  # number of valid positions
    ps_0 = tl.cumsum(counts_0, axis=0)
    ss_0 = total_0 - ps_0 + counts_0  # ss_0[b] = #elements in buckets >= b
    # Largest qualifying bucket, clamped to 0. If total_0 >= TOP_K bucket 0
    # always qualifies, so the clamp is a no-op. Otherwise no bucket qualifies
    # (-1) and the clamp drives the final threshold to 0, selecting every
    # valid position instead of none.
    pivot_0 = tl.maximum(tl.max(tl.where(ss_0 >= TOP_K, bins, -1)), 0)
    ca_0 = tl.sum(tl.where(bins > pivot_0, counts_0, 0))
    remaining_k = TOP_K - ca_0
    match_0 = (bucket_0 == pivot_0) & valid

    # Radix iteration 1: byte 2
    bucket_1 = (sortable >> 16) & 0xFF
    counts_1 = tl.histogram(bucket_1, 256, mask=match_0)
    total_1 = tl.sum(counts_1)
    ps_1 = tl.cumsum(counts_1, axis=0)
    ss_1 = total_1 - ps_1 + counts_1
    pivot_1 = tl.maximum(tl.max(tl.where(ss_1 >= remaining_k, bins, -1)), 0)
    ca_1 = tl.sum(tl.where(bins > pivot_1, counts_1, 0))
    remaining_k = remaining_k - ca_1
    match_1 = match_0 & (bucket_1 == pivot_1)

    # Radix iteration 2: byte 1
    bucket_2 = (sortable >> 8) & 0xFF
    counts_2 = tl.histogram(bucket_2, 256, mask=match_1)
    total_2 = tl.sum(counts_2)
    ps_2 = tl.cumsum(counts_2, axis=0)
    ss_2 = total_2 - ps_2 + counts_2
    pivot_2 = tl.maximum(tl.max(tl.where(ss_2 >= remaining_k, bins, -1)), 0)
    ca_2 = tl.sum(tl.where(bins > pivot_2, counts_2, 0))
    remaining_k = remaining_k - ca_2
    match_2 = match_1 & (bucket_2 == pivot_2)

    # Radix iteration 3: byte 0 (LSB)
    bucket_3 = sortable & 0xFF
    counts_3 = tl.histogram(bucket_3, 256, mask=match_2)
    total_3 = tl.sum(counts_3)
    ps_3 = tl.cumsum(counts_3, axis=0)
    ss_3 = total_3 - ps_3 + counts_3
    pivot_3 = tl.maximum(tl.max(tl.where(ss_3 >= remaining_k, bins, -1)), 0)
    ca_3 = tl.sum(tl.where(bins > pivot_3, counts_3, 0))
    remaining_k = remaining_k - ca_3

    # Selection: write indices for elements above threshold, then equal
    threshold = (pivot_0 << 24) | (pivot_1 << 16) | (pivot_2 << 8) | pivot_3
    above_total = TOP_K - remaining_k

    # Flip the sign bit so signed int32 comparison follows unsigned key order.
    s_shifted = sortable ^ tl.full(sortable.shape, _SIGN_BIT, dtype=tl.int32)
    t_shifted = threshold ^ _SIGN_BIT

    above = (s_shifted > t_shifted) & valid
    equal = (sortable == threshold) & valid

    pa = tl.cumsum(above.to(tl.int32), axis=0)
    tl.store(
        indices_ptr + pa - 1,
        offs.to(tl.int32),
        mask=above & (pa - 1 >= 0) & (pa - 1 < TOP_K),
    )

    # Ties go after all above-threshold slots, capped at the remaining K.
    pe = tl.cumsum(equal.to(tl.int32), axis=0)
    wpe = above_total + pe - 1
    tl.store(
        indices_ptr + wpe,
        offs.to(tl.int32),
        mask=equal & ((pe - 1) < remaining_k) & (wpe >= 0) & (wpe < TOP_K),
    )

    # Fewer valid positions than TOP_K: slots [0, total_0) hold every valid
    # index; mark the unused tail with -1 instead of leaving stale data.
    if total_0 < TOP_K:
        for start in tl.static_range(0, TOP_K, BLOCK):
            slot = start + offs
            tl.store(
                indices_ptr + slot,
                tl.full([BLOCK], -1, dtype=tl.int32),
                mask=(slot >= total_0) & (slot < TOP_K),
            )
@triton.jit
def _topk_medium_block(
    logits_ptr,
    seq_len_ptr,
    pb_hist_a_ptr,
    pb_hist_b_ptr,
    sync_ptr,
    counter_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
):
    """Multi-block radix select for medium vocab (8K-32K).

    All blocks participate in all 4 radix iterations using double-buffered
    per-block histograms. 4 barriers total (1 per iteration).

    For each iteration every program stores its 256-bin histogram into its
    own slot of buffer A or B (alternating), increments ``sync_ptr[it]``,
    spins until all NUM_BLOCKS programs have done so, then redundantly
    reduces every slot to the same global pivot. Because a program writes
    buffer A again (iteration 2) only after passing barrier 1, which every
    program reaches after reading A for iteration 0, no histogram is
    overwritten while still being read.

    Selection appends above-threshold indices through ``counter_ptr[0]`` and
    ties through ``counter_ptr[1]`` (ties follow all above-threshold slots and
    are capped at the remaining K). If fewer than TOP_K positions are valid,
    every valid index is written and program 0 fills the remaining slots
    with -1.

    ``sync_ptr[0:4]`` and ``counter_ptr[0:2]`` must be zero on entry; the
    last program to finish zeroes them for the next launch.

    NOTE(review): the spin barriers assume all NUM_BLOCKS programs are
    co-resident on the device; otherwise the kernel cannot make progress.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    seq_len = tl.load(seq_len_ptr)
    valid = (offs < N) & (offs < seq_len)

    vals = tl.load(logits_ptr + offs * stride1, mask=valid, other=float("-inf"))
    sortable = _float_to_sortable(vals)

    bins = tl.arange(0, 256)
    ha_base = pb_hist_a_ptr + pid * 256
    hb_base = pb_hist_b_ptr + pid * 256

    # Iteration 0: byte 3 (MSB), write to buf_A
    bucket_0 = (sortable >> 24) & 0xFF
    local_hist_0 = tl.histogram(bucket_0, 256, mask=valid)
    tl.store(ha_base + bins, local_hist_0)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr, 1)
    while tl.atomic_add(sync_ptr, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_a_ptr + i * 256 + bins)

    total_0 = tl.sum(counts)  # global number of valid positions
    ps_0 = tl.cumsum(counts, axis=0)
    ss_0 = total_0 - ps_0 + counts  # ss_0[b] = #elements in buckets >= b
    # Largest qualifying bucket, clamped to 0. If total_0 >= TOP_K bucket 0
    # always qualifies, so the clamp is a no-op. Otherwise no bucket qualifies
    # (-1) and the clamp drives the final threshold to 0, selecting every
    # valid position instead of none.
    pivot_0 = tl.maximum(tl.max(tl.where(ss_0 >= TOP_K, bins, -1)), 0)
    ca_0 = tl.sum(tl.where(bins > pivot_0, counts, 0))
    remaining_k = TOP_K - ca_0
    match = (bucket_0 == pivot_0) & valid

    # Iteration 1: byte 2, write to buf_B
    bucket_1 = (sortable >> 16) & 0xFF
    local_hist_1 = tl.histogram(bucket_1, 256, mask=match)
    tl.store(hb_base + bins, local_hist_1)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 1, 1)
    while tl.atomic_add(sync_ptr + 1, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_b_ptr + i * 256 + bins)

    total_1 = tl.sum(counts)
    ps_1 = tl.cumsum(counts, axis=0)
    ss_1 = total_1 - ps_1 + counts
    pivot_1 = tl.maximum(tl.max(tl.where(ss_1 >= remaining_k, bins, -1)), 0)
    ca_1 = tl.sum(tl.where(bins > pivot_1, counts, 0))
    remaining_k = remaining_k - ca_1
    match = match & (bucket_1 == pivot_1)

    # Iteration 2: byte 1, write to buf_A
    bucket_2 = (sortable >> 8) & 0xFF
    local_hist_2 = tl.histogram(bucket_2, 256, mask=match)
    tl.store(ha_base + bins, local_hist_2)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 2, 1)
    while tl.atomic_add(sync_ptr + 2, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_a_ptr + i * 256 + bins)

    total_2 = tl.sum(counts)
    ps_2 = tl.cumsum(counts, axis=0)
    ss_2 = total_2 - ps_2 + counts
    pivot_2 = tl.maximum(tl.max(tl.where(ss_2 >= remaining_k, bins, -1)), 0)
    ca_2 = tl.sum(tl.where(bins > pivot_2, counts, 0))
    remaining_k = remaining_k - ca_2
    match = match & (bucket_2 == pivot_2)

    # Iteration 3: byte 0 (LSB), write to buf_B
    bucket_3 = sortable & 0xFF
    local_hist_3 = tl.histogram(bucket_3, 256, mask=match)
    tl.store(hb_base + bins, local_hist_3)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 3, 1)
    while tl.atomic_add(sync_ptr + 3, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_b_ptr + i * 256 + bins)

    total_3 = tl.sum(counts)
    ps_3 = tl.cumsum(counts, axis=0)
    ss_3 = total_3 - ps_3 + counts
    pivot_3 = tl.maximum(tl.max(tl.where(ss_3 >= remaining_k, bins, -1)), 0)
    ca_3 = tl.sum(tl.where(bins > pivot_3, counts, 0))
    remaining_k = remaining_k - ca_3

    # Selection phase
    threshold = (pivot_0 << 24) | (pivot_1 << 16) | (pivot_2 << 8) | pivot_3
    above_total = TOP_K - remaining_k

    # Flip the sign bit so signed int32 comparison follows unsigned key order.
    s_shifted = sortable ^ tl.full(sortable.shape, _SIGN_BIT, dtype=tl.int32)
    t_shifted = threshold ^ _SIGN_BIT

    above = (s_shifted > t_shifted) & valid
    equal = (sortable == threshold) & valid

    n_above = tl.sum(above.to(tl.int32))
    if n_above > 0:
        pa = tl.cumsum(above.to(tl.int32), axis=0)
        base_a = tl.atomic_add(counter_ptr, n_above)
        wp = base_a + pa - 1
        tl.store(
            indices_ptr + wp,
            offs.to(tl.int32),
            mask=above & (wp >= 0) & (wp < TOP_K),
        )

    n_equal = tl.sum(equal.to(tl.int32))
    if n_equal > 0:
        pe = tl.cumsum(equal.to(tl.int32), axis=0)
        base_e = tl.atomic_add(counter_ptr + 1, n_equal)
        wpe = above_total + base_e + pe - 1
        tl.store(
            indices_ptr + wpe,
            offs.to(tl.int32),
            mask=equal & ((base_e + pe - 1) < remaining_k) & (wpe >= 0) & (wpe < TOP_K),
        )

    # Fewer valid positions than TOP_K: slots [0, total_0) hold every valid
    # index; program 0 marks the unused tail with -1.
    if (pid == 0) & (total_0 < TOP_K):
        for start in tl.static_range(0, TOP_K, BLOCK):
            slot = start + tl.arange(0, BLOCK)
            tl.store(
                indices_ptr + slot,
                tl.full([BLOCK], -1, dtype=tl.int32),
                mask=(slot >= total_0) & (slot < TOP_K),
            )

    # Zero shared state for the next call -- but only from the LAST program to
    # finish. Resetting from pid 0 unconditionally could clear counter_ptr
    # while other programs still append through it, or clear sync_ptr[3]
    # while they still spin on it. sync_ptr[0] is reused as an exit ticket:
    # once any program gets here every program has passed barrier 0, so no
    # one reads it anymore; it reaches 2 * NUM_BLOCKS after the final exit.
    tl.debug_barrier()
    exit_ticket = tl.atomic_add(sync_ptr, 1)
    if exit_ticket == 2 * NUM_BLOCKS - 1:
        tl.store(sync_ptr + tl.arange(0, 4), tl.zeros([4], dtype=tl.int32))
        tl.store(counter_ptr, 0)
        tl.store(counter_ptr + 1, 0)
@triton.jit
def _topk_multi_block(
    logits_ptr,
    seq_len_ptr,
    pb_hist_ptr,
    sync_ptr,
    buf_val_ptr,
    buf_idx_ptr,
    counter_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
    BUF_SIZE: tl.constexpr,
):
    """Multi-block radix select for large vocab (>32K).

    Iteration 0: all blocks compute byte-3 histograms + barrier + reduce.
    Iterations 1-3: block-0 only, operating on a compacted buffer of
    elements matching the byte-3 pivot. Avoids barrier overhead for
    high block counts (e.g. 32 blocks for vocab=129280).

    Elements whose top byte is strictly above the pivot are final winners and
    are written straight to ``indices_ptr`` through the ``counter_ptr[0]``
    cursor. Elements in the pivot bucket are appended to
    ``buf_val_ptr``/``buf_idx_ptr`` through ``counter_ptr[1]``. After every
    program has signalled ``sync_ptr[1]``, program 0 resolves the lower three
    bytes over the buffer and writes the remaining winners. When fewer than
    TOP_K positions are valid, it also fills the unused output slots with -1.
    It then zeroes ``sync_ptr[0:2]`` and ``counter_ptr[0:2]``. That reset is
    safe: every program finished its counter atomics before signalling
    ``sync_ptr[1]``, and no program still spins on ``sync_ptr[0]``.

    NOTE(review): only the first BUF_SIZE pivot-bucket candidates are kept
    (stores with ``bp >= BUF_SIZE`` are masked off). If more elements share
    the byte-3 pivot, candidates are silently dropped and the selection can
    be wrong. The spin barriers also assume all NUM_BLOCKS programs are
    co-resident on the device.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    seq_len = tl.load(seq_len_ptr)
    valid = (offs < N) & (offs < seq_len)

    vals = tl.load(logits_ptr + offs * stride1, mask=valid, other=float("-inf"))
    sortable = _float_to_sortable(vals)

    # Iteration 0: all blocks compute byte-3 histogram
    bucket = (sortable >> 24) & 0xFF
    local_hist = tl.histogram(bucket, 256, mask=valid)

    bins = tl.arange(0, 256)
    h_base = pb_hist_ptr + pid * 256
    tl.store(h_base + bins, local_hist)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr, 1)
    while tl.atomic_add(sync_ptr, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_ptr + i * 256 + bins)

    total = tl.sum(counts)  # global number of valid positions
    ps = tl.cumsum(counts, axis=0)
    ss = total - ps + counts  # ss[b] = #elements in buckets >= b
    # -1 when total < TOP_K: then every valid element counts as "above".
    pivot_0 = tl.max(tl.where(ss >= TOP_K, bins, -1))
    count_above_0 = tl.sum(tl.where(bins > pivot_0, counts, 0))
    remaining_k = TOP_K - count_above_0

    above = (bucket > pivot_0) & valid
    match = (bucket == pivot_0) & valid

    # Write above-threshold indices directly to output
    n_above = tl.sum(above.to(tl.int32))
    if n_above > 0:
        pa = tl.cumsum(above.to(tl.int32), axis=0)
        base_a = tl.atomic_add(counter_ptr, n_above)
        wp = base_a + pa - 1
        tl.store(
            indices_ptr + wp,
            offs.to(tl.int32),
            mask=above & (wp >= 0) & (wp < TOP_K),
        )

    # Compact matching elements into buffer for block-0
    n_match = tl.sum(match.to(tl.int32))
    if n_match > 0:
        pm = tl.cumsum(match.to(tl.int32), axis=0)
        base_m = tl.atomic_add(counter_ptr + 1, n_match)
        bp = base_m + pm - 1
        tl.store(
            buf_val_ptr + bp,
            sortable,
            mask=match & (bp >= 0) & (bp < BUF_SIZE),
        )
        tl.store(
            buf_idx_ptr + bp,
            offs.to(tl.int32),
            mask=match & (bp >= 0) & (bp < BUF_SIZE),
        )

    # Iterations 1-3: block-0 processes compacted buffer
    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 1, 1)
    if pid == 0:
        while tl.atomic_add(sync_ptr + 1, 0) < NUM_BLOCKS:
            pass

        buf_count = tl.atomic_add(counter_ptr + 1, 0)

        b_offs = tl.arange(0, BUF_SIZE)
        b_valid = b_offs < buf_count
        b_vals = tl.load(buf_val_ptr + b_offs, mask=b_valid, other=0)
        b_idxs = tl.load(buf_idx_ptr + b_offs, mask=b_valid, other=0)

        # Iteration 1: byte 2
        b_byte_1 = (b_vals >> 16) & 0xFF
        counts_1 = tl.histogram(b_byte_1, 256, mask=b_valid)
        total_1 = tl.sum(counts_1)
        ps_1 = tl.cumsum(counts_1, axis=0)
        ss_1 = total_1 - ps_1 + counts_1
        pivot_1 = tl.max(tl.where(ss_1 >= remaining_k, bins, -1))
        ca_1 = tl.sum(tl.where(bins > pivot_1, counts_1, 0))
        remaining_k = remaining_k - ca_1

        # Iteration 2: byte 1
        prefix_hi16 = (pivot_0 << 8) | pivot_1
        upper16 = (b_vals >> 16) & 0xFFFF
        b_match_2 = (upper16 == prefix_hi16) & b_valid
        b_bucket_2 = (b_vals >> 8) & 0xFF
        counts_2 = tl.histogram(b_bucket_2, 256, mask=b_match_2)
        total_2 = tl.sum(counts_2)
        ps_2 = tl.cumsum(counts_2, axis=0)
        ss_2 = total_2 - ps_2 + counts_2
        pivot_2 = tl.max(tl.where(ss_2 >= remaining_k, bins, -1))
        ca_2 = tl.sum(tl.where(bins > pivot_2, counts_2, 0))
        remaining_k = remaining_k - ca_2

        # Iteration 3: byte 0 (LSB)
        prefix_hi24 = (prefix_hi16 << 8) | pivot_2
        upper24 = (b_vals >> 8) & 0xFFFFFF
        b_match_3 = (upper24 == prefix_hi24) & b_valid
        b_bucket_3 = b_vals & 0xFF
        counts_3 = tl.histogram(b_bucket_3, 256, mask=b_match_3)
        total_3 = tl.sum(counts_3)
        ps_3 = tl.cumsum(counts_3, axis=0)
        ss_3 = total_3 - ps_3 + counts_3
        pivot_3 = tl.max(tl.where(ss_3 >= remaining_k, bins, -1))
        ca_3 = tl.sum(tl.where(bins > pivot_3, counts_3, 0))
        remaining_k = remaining_k - ca_3

        # Final selection from buffer
        threshold = (prefix_hi24 << 8) | pivot_3
        above_total = TOP_K - remaining_k

        # Flip the sign bit so signed int32 comparison follows unsigned order.
        s_sh = b_vals ^ tl.full(b_vals.shape, _SIGN_BIT, dtype=tl.int32)
        t_sh = threshold ^ _SIGN_BIT

        above_buf = (s_sh > t_sh) & b_valid
        equal_buf = (b_vals == threshold) & b_valid

        pa_b = tl.cumsum(above_buf.to(tl.int32), axis=0)
        wp_b = count_above_0 + pa_b - 1
        tl.store(
            indices_ptr + wp_b,
            b_idxs,
            mask=above_buf & (wp_b >= 0) & (wp_b < TOP_K),
        )

        pe_b = tl.cumsum(equal_buf.to(tl.int32), axis=0)
        wpe_b = above_total + pe_b - 1
        tl.store(
            indices_ptr + wpe_b,
            b_idxs,
            mask=equal_buf
            & ((pe_b - 1) < remaining_k)
            & (wpe_b >= 0)
            & (wpe_b < TOP_K),
        )

        # Fewer valid positions than TOP_K: pivot_0 was -1, so every valid
        # index was written to [0, total) as "above". Mark the unused tail
        # with -1 instead of leaving stale data (disjoint from those writes).
        if total < TOP_K:
            for start in tl.static_range(0, TOP_K, BLOCK):
                slot = start + tl.arange(0, BLOCK)
                tl.store(
                    indices_ptr + slot,
                    tl.full([BLOCK], -1, dtype=tl.int32),
                    mask=(slot >= total) & (slot < TOP_K),
                )

        tl.store(sync_ptr, 0)
        tl.store(sync_ptr + 1, 0)
        tl.store(counter_ptr, 0)
        tl.store(counter_ptr + 1, 0)
# Persistent scratch buffers keyed by (device_index, dispatch_tier), where the
# tier is "med" or "large". Allocated on first use per device and reused by
# later calls so that no device allocation happens on subsequent calls.
_cache = {}

# Inclusive upper bounds on vocab_size for the single-block and medium tiers;
# anything larger goes to the large tier.
_SINGLE_BLOCK_LIMIT = 8 * 1024
_MEDIUM_BLOCK_LIMIT = 32 * 1024

# Logits handled per program in the two multi-block tiers.
_MEDIUM_BLOCK_SIZE = 4 * 1024
_LARGE_BLOCK_SIZE = 4 * 1024

# Capacity of the compacted pivot-bucket candidate buffer in the large tier.
_LARGE_BUF_SIZE = 4 * 1024
def top_k_per_row_decode(
    logits, next_n, seq_lens, indices, num_rows, stride0, stride1, top_k
):
    """Top-K per row for decode phase of DeepSeek V4.

    Selects top_k indices from a single row of logits using radix-based
    selection. Only valid elements within [0, seq_lens[0]) are considered.
    Dispatch by ``vocab_size``: ``_topk_single_block`` up to
    ``_SINGLE_BLOCK_LIMIT``, ``_topk_medium_block`` up to
    ``_MEDIUM_BLOCK_LIMIT``, otherwise ``_topk_multi_block``.

    Args:
        logits: [1, vocab_size] float32 tensor.
        next_n: number of next tokens (unused, kept for API compatibility).
        seq_lens: [1] int32 — valid range [0, seq_lens[0]).
        indices: [1, top_k] int32 — output buffer, filled with selected
            indices. Must be viewable as 1-D (``indices.view(-1)``).
        num_rows: must be 1 (decode processes one row at a time).
        stride0: logits.stride(0) (unused by this implementation).
        stride1: logits.stride(1).
        top_k: number of top elements to select.

    Raises:
        AssertionError: if ``num_rows != 1``.

    Note:
        The multi-block tiers use persistent scratch from ``_cache`` that is
        keyed only by device and tier, so concurrent launches of the same
        tier on different streams of one device would share (and race on) it.
    """
    logger.debug("GEMS TOP_K_PER_ROW_DECODE")

    assert num_rows == 1, "Only num_rows=1 supported in decode path"

    vocab_size = logits.shape[1]
    device = logits.device
    ind = indices.view(-1)
    dev_idx = device.index if device.index is not None else 0

    if vocab_size <= _SINGLE_BLOCK_LIMIT:
        # Whole row in one program; rows fitting in half the limit use a
        # smaller tile and fewer warps.
        if vocab_size <= _SINGLE_BLOCK_LIMIT // 2:
            block, warps = _SINGLE_BLOCK_LIMIT // 2, 8
        else:
            block, warps = _SINGLE_BLOCK_LIMIT, 16
        _topk_single_block[(1,)](
            logits,
            seq_lens,
            ind,
            stride1,
            N=vocab_size,
            BLOCK=block,
            TOP_K=top_k,
            num_warps=warps,
        )
    elif vocab_size <= _MEDIUM_BLOCK_LIMIT:
        # Medium vocab: double-buffered all-blocks radix
        n_blocks = (vocab_size + _MEDIUM_BLOCK_SIZE - 1) // _MEDIUM_BLOCK_SIZE
        key = (dev_idx, "med")
        if key not in _cache:
            # Sized for the largest row this tier accepts.
            max_nb = (
                _MEDIUM_BLOCK_LIMIT + _MEDIUM_BLOCK_SIZE - 1
            ) // _MEDIUM_BLOCK_SIZE
            pb_size = max_nb * 256
            _cache[key] = (
                torch.zeros(pb_size, dtype=torch.int32, device=device),  # hist A
                torch.zeros(pb_size, dtype=torch.int32, device=device),  # hist B
                torch.zeros(4, dtype=torch.int32, device=device),  # sync
                torch.zeros(2, dtype=torch.int32, device=device),  # counter
            )
        pb_hist_a, pb_hist_b, sync, counter = _cache[key]

        _topk_medium_block[(n_blocks,)](
            logits,
            seq_lens,
            pb_hist_a,
            pb_hist_b,
            sync,
            counter,
            ind,
            stride1,
            N=vocab_size,
            NUM_BLOCKS=n_blocks,
            BLOCK=_MEDIUM_BLOCK_SIZE,
            TOP_K=top_k,
            num_warps=8,
        )
    else:
        # Large vocab: buffer-based multi-block radix
        n_blocks = (vocab_size + _LARGE_BLOCK_SIZE - 1) // _LARGE_BLOCK_SIZE
        key = (dev_idx, "large")
        cached = _cache.get(key)
        # Each program writes 256 histogram counts at pid * 256. A fixed
        # 64-block region let rows longer than 64 * _LARGE_BLOCK_SIZE write
        # over the sync/counter slices and past the allocation, so size the
        # region for n_blocks and grow the cached scratch when needed.
        # NOTE(review): very large n_blocks also require all programs to be
        # co-resident for the kernel's spin barrier.
        if cached is None or cached[0].numel() < n_blocks * 256:
            max_nb = max(64, n_blocks)
            pb_size = max_nb * 256
            # Layout: [pb_hist | sync(2) | counter(2)], zero-initialized.
            scratch = torch.zeros(pb_size + 4, dtype=torch.int32, device=device)
            buf = torch.empty(_LARGE_BUF_SIZE * 2, dtype=torch.int32, device=device)
            _cache[key] = (
                scratch[:pb_size],
                scratch[pb_size : pb_size + 2],
                buf[:_LARGE_BUF_SIZE],
                buf[_LARGE_BUF_SIZE:],
                scratch[pb_size + 2 : pb_size + 4],
            )
        pb_hist, sync, buf_val, buf_idx, counter = _cache[key]

        _topk_multi_block[(n_blocks,)](
            logits,
            seq_lens,
            pb_hist,
            sync,
            buf_val,
            buf_idx,
            counter,
            ind,
            stride1,
            N=vocab_size,
            NUM_BLOCKS=n_blocks,
            BLOCK=_LARGE_BLOCK_SIZE,
            TOP_K=top_k,
            BUF_SIZE=_LARGE_BUF_SIZE,
            num_warps=8,
        )