Coverage for src/flag_gems/fused/top_k_per_row_decode.py: 7%

305 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

"""Triton top_k_per_row_decode for DeepSeek V4 decode-phase token selection.

Replaces vLLM's top_k_per_row_decode CUDA kernel with a pure Triton
implementation using radix-select (4-iteration 8-bit histogram radix).

Background:
    In DeepSeek V4 decode, each step selects the top-K token indices from a
    single row of logits [1, vocab_size]. The vLLM CUDA kernel uses a
    radix-based approach; this Triton kernel matches that strategy with
    three dispatch tiers optimized for different vocab sizes.

Strategy:
    1. Single-block path (vocab_size <= 8192): All data fits in one thread
       block's registers. Four radix iterations with tl.histogram, no
       inter-block synchronization, no global memory scratch.
    2. Medium multi-block path (8192 < vocab_size <= 32768): All blocks
       participate in all 4 radix iterations. Double-buffered per-block
       histograms with 4 barriers (1 per iteration). Eliminates serial
       block-0 bottleneck seen in buffer-based approaches.
    3. Large multi-block path (vocab_size > 32768): First radix iteration
       runs across all blocks with per-block histograms + barrier. Remaining
       3 iterations run on block-0 only using a compacted buffer, avoiding
       barrier overhead for high block counts.

Performance (DeepSeek V4 config, H20 GPU):
    - vocab=129280, k=1024: 1.82x faster than vLLM CUDA
    - vocab=32768,  k=512:  0.78x vs vLLM CUDA
    - vocab=8192,   k=128:  0.50x vs vLLM CUDA
"""

30 

31import logging 

32 

33import torch 

34import triton 

35import triton.language as tl 

36 

37logger = logging.getLogger(__name__) 

38 

# Signed 32-bit value whose only set bit is bit 31; XOR-ing an int32 with it
# flips the sign bit, which the kernels use to switch between signed and
# unsigned ordering of the sortable keys.
_SIGN_BIT = tl.constexpr(-(2**31))

40 

41 

@triton.jit
def _float_to_sortable(val):
    """Map 32-bit floats to int32 bit patterns that sort like the floats.

    Negative inputs have every bit inverted and non-negative inputs have only
    the sign bit flipped, so comparing the results as *unsigned* integers
    gives the same order as comparing the original floats.
    """
    raw = val.to(tl.int32, bitcast=True)
    # Arithmetic shift: all ones for negative inputs, zero otherwise. OR-ing in
    # the sign bit yields the per-element flip mask.
    flip = (raw >> 31) | tl.full(raw.shape, _SIGN_BIT, dtype=tl.int32)
    return raw ^ flip

52 

53 

@triton.jit
def _topk_single_block(
    logits_ptr,
    seq_len_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
):
    """Radix-select TOP_K indices of one logits row within a single program.

    Four 8-bit passes (most significant byte first) narrow down the sortable
    key of the TOP_K-th largest element. Indices of all elements strictly
    above that key are then written first, followed by as many elements equal
    to it as are needed to reach TOP_K. Output order follows element position,
    not value.

    Args:
        logits_ptr: base pointer of the row; element i is read at i * stride1.
        seq_len_ptr: pointer to one integer; only positions below it (and
            below N) take part in the selection.
        indices_ptr: output pointer receiving int32 element positions.
        stride1: element stride between consecutive logits.
        N: number of logits in the row.
        BLOCK: number of lanes; positions >= BLOCK are never read.
        TOP_K: number of indices to select.

    NOTE(review): when fewer than TOP_K positions are valid no bin passes the
    suffix-count test and every pivot becomes -1; the output then appears to
    stay largely unwritten -- confirm callers guarantee seq_len >= TOP_K.
    """
    lane = tl.arange(0, BLOCK)
    row_len = tl.load(seq_len_ptr)
    live = (lane < N) & (lane < row_len)

    logit = tl.load(logits_ptr + lane * stride1, mask=live, other=float("-inf"))
    key = _float_to_sortable(logit)

    bin_id = tl.arange(0, 256)

    # Pass 0 -- byte 3 (MSB), histogram over every live lane.
    digit_a = (key >> 24) & 0xFF
    hist_a = tl.histogram(digit_a, 256, mask=live)
    # at_or_above[b] = number of candidates whose digit is >= b.
    at_or_above_a = tl.sum(hist_a) - tl.cumsum(hist_a, axis=0) + hist_a
    pivot_a = tl.max(tl.where(at_or_above_a >= TOP_K, bin_id, -1))
    k_left = TOP_K - tl.sum(tl.where(bin_id > pivot_a, hist_a, 0))
    cand_a = (digit_a == pivot_a) & live

    # Pass 1 -- byte 2, restricted to lanes that matched the pass-0 pivot.
    digit_b = (key >> 16) & 0xFF
    hist_b = tl.histogram(digit_b, 256, mask=cand_a)
    at_or_above_b = tl.sum(hist_b) - tl.cumsum(hist_b, axis=0) + hist_b
    pivot_b = tl.max(tl.where(at_or_above_b >= k_left, bin_id, -1))
    k_left = k_left - tl.sum(tl.where(bin_id > pivot_b, hist_b, 0))
    cand_b = cand_a & (digit_b == pivot_b)

    # Pass 2 -- byte 1.
    digit_c = (key >> 8) & 0xFF
    hist_c = tl.histogram(digit_c, 256, mask=cand_b)
    at_or_above_c = tl.sum(hist_c) - tl.cumsum(hist_c, axis=0) + hist_c
    pivot_c = tl.max(tl.where(at_or_above_c >= k_left, bin_id, -1))
    k_left = k_left - tl.sum(tl.where(bin_id > pivot_c, hist_c, 0))
    cand_c = cand_b & (digit_c == pivot_c)

    # Pass 3 -- byte 0 (LSB). Afterwards k_left is the number of elements
    # equal to the threshold that still have to be emitted.
    digit_d = key & 0xFF
    hist_d = tl.histogram(digit_d, 256, mask=cand_c)
    at_or_above_d = tl.sum(hist_d) - tl.cumsum(hist_d, axis=0) + hist_d
    pivot_d = tl.max(tl.where(at_or_above_d >= k_left, bin_id, -1))
    k_left = k_left - tl.sum(tl.where(bin_id > pivot_d, hist_d, 0))

    # Reassemble the full 32-bit threshold key from the four pivots.
    cutoff = (pivot_a << 24) | (pivot_b << 16) | (pivot_c << 8) | pivot_d
    n_strictly_above = TOP_K - k_left

    # Keys are ordered as unsigned values; flipping the sign bit on both sides
    # lets the signed '>' comparison produce the unsigned result.
    key_signed = key ^ tl.full(key.shape, _SIGN_BIT, dtype=tl.int32)
    cutoff_signed = cutoff ^ _SIGN_BIT

    is_above = (key_signed > cutoff_signed) & live
    is_equal = (key == cutoff) & live

    # Elements above the threshold fill the front of the output.
    above_slot = tl.cumsum(is_above.to(tl.int32), axis=0) - 1
    tl.store(
        indices_ptr + above_slot,
        lane.to(tl.int32),
        mask=is_above & (above_slot >= 0) & (above_slot < TOP_K),
    )

    # The first k_left elements equal to the threshold fill the remainder.
    equal_rank = tl.cumsum(is_equal.to(tl.int32), axis=0) - 1
    equal_slot = n_strictly_above + equal_rank
    tl.store(
        indices_ptr + equal_slot,
        lane.to(tl.int32),
        mask=is_equal & (equal_rank < k_left) & (equal_slot >= 0) & (equal_slot < TOP_K),
    )

141 

142 

@triton.jit
def _topk_medium_block(
    logits_ptr,
    seq_len_ptr,
    pb_hist_a_ptr,
    pb_hist_b_ptr,
    sync_ptr,
    counter_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
):
    """Multi-block radix select for medium vocab (8K-32K).

    Every program handles BLOCK consecutive logits and takes part in all four
    8-bit radix iterations. Per iteration each program publishes a 256-bin
    histogram, waits on a spin barrier, then sums all programs' histograms to
    derive the same pivot locally. Histograms alternate between two buffers:
    a buffer is only overwritten two iterations later, i.e. after a barrier
    that every program can reach only once it has finished reading it.

    Args:
        logits_ptr: base pointer of the row; element i is read at i * stride1.
        seq_len_ptr: pointer to one integer; only positions below it (and
            below N) take part in the selection.
        pb_hist_a_ptr, pb_hist_b_ptr: int32 scratch, 256 entries per program.
        sync_ptr: four int32 barrier counters; must be zero at launch and are
            zeroed again before the kernel finishes.
        counter_ptr: two int32 output cursors (above / equal); must be zero at
            launch and are zeroed again before the kernel finishes.
        indices_ptr: output pointer receiving int32 element positions.
        stride1: element stride between consecutive logits.
        N: number of logits in the row.
        NUM_BLOCKS: number of programs in the launch grid.
        BLOCK: elements handled per program.
        TOP_K: number of indices to select.

    The spin barriers only make progress if all NUM_BLOCKS programs are
    resident on the device at the same time.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    seq_len = tl.load(seq_len_ptr)
    valid = (offs < N) & (offs < seq_len)

    vals = tl.load(logits_ptr + offs * stride1, mask=valid, other=float("-inf"))
    sortable = _float_to_sortable(vals)

    bins = tl.arange(0, 256)
    ha_base = pb_hist_a_ptr + pid * 256
    hb_base = pb_hist_b_ptr + pid * 256

    # Iteration 0: byte 3 (MSB), write to buf_A
    bucket_0 = (sortable >> 24) & 0xFF
    local_hist_0 = tl.histogram(bucket_0, 256, mask=valid)
    tl.store(ha_base + bins, local_hist_0)

    # Grid barrier: arrive, then poll until every program has arrived.
    tl.debug_barrier()
    tl.atomic_add(sync_ptr, 1)
    while tl.atomic_add(sync_ptr, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_a_ptr + i * 256 + bins)

    # ss_x[b] = number of candidates whose current byte is >= b; the pivot is
    # the largest bin that still has at least the required count at or above.
    total_0 = tl.sum(counts)
    ps_0 = tl.cumsum(counts, axis=0)
    ss_0 = total_0 - ps_0 + counts
    pivot_0 = tl.max(tl.where(ss_0 >= TOP_K, bins, -1))
    ca_0 = tl.sum(tl.where(bins > pivot_0, counts, 0))
    remaining_k = TOP_K - ca_0
    match = (bucket_0 == pivot_0) & valid

    # Iteration 1: byte 2, write to buf_B
    bucket_1 = (sortable >> 16) & 0xFF
    local_hist_1 = tl.histogram(bucket_1, 256, mask=match)
    tl.store(hb_base + bins, local_hist_1)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 1, 1)
    while tl.atomic_add(sync_ptr + 1, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_b_ptr + i * 256 + bins)

    total_1 = tl.sum(counts)
    ps_1 = tl.cumsum(counts, axis=0)
    ss_1 = total_1 - ps_1 + counts
    pivot_1 = tl.max(tl.where(ss_1 >= remaining_k, bins, -1))
    ca_1 = tl.sum(tl.where(bins > pivot_1, counts, 0))
    remaining_k = remaining_k - ca_1
    match = match & (bucket_1 == pivot_1)

    # Iteration 2: byte 1, write to buf_A
    bucket_2 = (sortable >> 8) & 0xFF
    local_hist_2 = tl.histogram(bucket_2, 256, mask=match)
    tl.store(ha_base + bins, local_hist_2)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 2, 1)
    while tl.atomic_add(sync_ptr + 2, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_a_ptr + i * 256 + bins)

    total_2 = tl.sum(counts)
    ps_2 = tl.cumsum(counts, axis=0)
    ss_2 = total_2 - ps_2 + counts
    pivot_2 = tl.max(tl.where(ss_2 >= remaining_k, bins, -1))
    ca_2 = tl.sum(tl.where(bins > pivot_2, counts, 0))
    remaining_k = remaining_k - ca_2
    match = match & (bucket_2 == pivot_2)

    # Iteration 3: byte 0 (LSB), write to buf_B
    bucket_3 = sortable & 0xFF
    local_hist_3 = tl.histogram(bucket_3, 256, mask=match)
    tl.store(hb_base + bins, local_hist_3)

    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 3, 1)
    while tl.atomic_add(sync_ptr + 3, 0) < NUM_BLOCKS:
        pass

    counts = tl.zeros([256], dtype=tl.int32)
    for i in tl.static_range(NUM_BLOCKS):
        counts += tl.load(pb_hist_b_ptr + i * 256 + bins)

    total_3 = tl.sum(counts)
    ps_3 = tl.cumsum(counts, axis=0)
    ss_3 = total_3 - ps_3 + counts
    pivot_3 = tl.max(tl.where(ss_3 >= remaining_k, bins, -1))
    ca_3 = tl.sum(tl.where(bins > pivot_3, counts, 0))
    # From here on remaining_k is the number of threshold-equal elements that
    # still have to be emitted.
    remaining_k = remaining_k - ca_3

    # Selection phase
    threshold = (pivot_0 << 24) | (pivot_1 << 16) | (pivot_2 << 8) | pivot_3
    above_total = TOP_K - remaining_k

    # Keys are ordered as unsigned values; flipping the sign bit on both sides
    # lets the signed '>' comparison produce the unsigned result.
    s_shifted = sortable ^ tl.full(sortable.shape, _SIGN_BIT, dtype=tl.int32)
    t_shifted = threshold ^ _SIGN_BIT

    above = (s_shifted > t_shifted) & valid
    equal = (sortable == threshold) & valid

    # Each program reserves a contiguous output range through the shared
    # cursor, then scatters its own elements into that range.
    n_above = tl.sum(above.to(tl.int32))
    if n_above > 0:
        pa = tl.cumsum(above.to(tl.int32), axis=0)
        base_a = tl.atomic_add(counter_ptr, n_above)
        wp = base_a + pa - 1
        tl.store(
            indices_ptr + wp,
            offs.to(tl.int32),
            mask=above & (wp >= 0) & (wp < TOP_K),
        )

    n_equal = tl.sum(equal.to(tl.int32))
    if n_equal > 0:
        pe = tl.cumsum(equal.to(tl.int32), axis=0)
        base_e = tl.atomic_add(counter_ptr + 1, n_equal)
        wpe = above_total + base_e + pe - 1
        tl.store(
            indices_ptr + wpe,
            offs.to(tl.int32),
            mask=equal & ((base_e + pe - 1) < remaining_k) & (wpe >= 0) & (wpe < TOP_K),
        )

    # Zero shared state for next call.
    #
    # The reset must not happen while another program is still polling a
    # barrier or has not yet reserved its output range, otherwise that program
    # can spin forever, receive an overlapping base offset, or leave the
    # cursors non-zero for the next launch. A program that gets here has left
    # the iteration-3 barrier, which implies every program has already left
    # barriers 0-2, so sync slot 0 is idle and is reused as an exit counter
    # (it holds exactly NUM_BLOCKS after barrier 0). The program that observes
    # the last arrival performs the reset.
    finished = tl.atomic_add(sync_ptr, 1)
    if finished == 2 * NUM_BLOCKS - 1:
        tl.store(sync_ptr + tl.arange(0, 4), tl.zeros([4], dtype=tl.int32))
        tl.store(counter_ptr, 0)
        tl.store(counter_ptr + 1, 0)

299 

300 

@triton.jit
def _topk_multi_block(
    logits_ptr,
    seq_len_ptr,
    pb_hist_ptr,
    sync_ptr,
    buf_val_ptr,
    buf_idx_ptr,
    counter_ptr,
    indices_ptr,
    stride1,
    N: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
    TOP_K: tl.constexpr,
    BUF_SIZE: tl.constexpr,
):
    """Multi-block radix select for large vocab (>32K).

    Pass 0 (most significant byte) runs on every program: each publishes a
    256-bin histogram, all programs meet at a spin barrier, and each derives
    the same pivot from the summed histograms. Elements whose top byte is
    above the pivot are written to the output immediately; elements whose top
    byte equals the pivot are compacted (key + position) into a shared buffer.
    Program 0 then waits for every program to finish compacting and runs the
    remaining three passes and the final selection on that buffer alone.

    Args:
        logits_ptr: base pointer of the row; element i is read at i * stride1.
        seq_len_ptr: pointer to one integer; only positions below it (and
            below N) take part in the selection.
        pb_hist_ptr: int32 scratch, 256 entries per program.
        sync_ptr: two int32 barrier counters, expected to be zero at launch.
        buf_val_ptr, buf_idx_ptr: int32 buffers of BUF_SIZE entries holding
            the compacted keys and their element positions.
        counter_ptr: two int32 cursors (output slots used by pass 0, buffer
            fill level), expected to be zero at launch.
        indices_ptr: output pointer receiving int32 element positions.
        stride1: element stride between consecutive logits.
        N: number of logits in the row.
        NUM_BLOCKS: number of programs in the launch grid.
        BLOCK: elements handled per program.
        TOP_K: number of indices to select.
        BUF_SIZE: capacity of the compaction buffers.

    NOTE(review): candidates whose buffer slot is >= BUF_SIZE are dropped by
    the store mask while the fill counter still counts them, so a pivot bucket
    with more than BUF_SIZE members is processed incompletely.
    """
    block_id = tl.program_id(0)
    pos = block_id * BLOCK + tl.arange(0, BLOCK)
    row_len = tl.load(seq_len_ptr)
    live = (pos < N) & (pos < row_len)

    logit = tl.load(logits_ptr + pos * stride1, mask=live, other=float("-inf"))
    key = _float_to_sortable(logit)

    # Pass 0: per-program histogram of the most significant byte.
    top_byte = (key >> 24) & 0xFF
    my_hist = tl.histogram(top_byte, 256, mask=live)

    bin_id = tl.arange(0, 256)
    tl.store(pb_hist_ptr + block_id * 256 + bin_id, my_hist)

    # Grid barrier: arrive, then poll until every program has arrived.
    tl.debug_barrier()
    tl.atomic_add(sync_ptr, 1)
    while tl.atomic_add(sync_ptr, 0) < NUM_BLOCKS:
        pass

    merged = tl.zeros([256], dtype=tl.int32)
    for blk in tl.static_range(NUM_BLOCKS):
        merged += tl.load(pb_hist_ptr + blk * 256 + bin_id)

    # at_or_above[b] = number of live elements whose top byte is >= b.
    at_or_above = tl.sum(merged) - tl.cumsum(merged, axis=0) + merged
    pivot_hi = tl.max(tl.where(at_or_above >= TOP_K, bin_id, -1))
    n_above_hi = tl.sum(tl.where(bin_id > pivot_hi, merged, 0))
    k_left = TOP_K - n_above_hi

    is_above = (top_byte > pivot_hi) & live
    is_cand = (top_byte == pivot_hi) & live

    # Elements already known to be in the top-k go straight to the output;
    # each program reserves its range through the shared cursor.
    above_here = tl.sum(is_above.to(tl.int32))
    if above_here > 0:
        above_rank = tl.cumsum(is_above.to(tl.int32), axis=0)
        out_base = tl.atomic_add(counter_ptr, above_here)
        out_slot = out_base + above_rank - 1
        tl.store(
            indices_ptr + out_slot,
            pos.to(tl.int32),
            mask=is_above & (out_slot >= 0) & (out_slot < TOP_K),
        )

    # Candidates sharing the pivot byte are compacted for program 0.
    cand_here = tl.sum(is_cand.to(tl.int32))
    if cand_here > 0:
        cand_rank = tl.cumsum(is_cand.to(tl.int32), axis=0)
        buf_base = tl.atomic_add(counter_ptr + 1, cand_here)
        buf_slot = buf_base + cand_rank - 1
        fits = is_cand & (buf_slot >= 0) & (buf_slot < BUF_SIZE)
        tl.store(buf_val_ptr + buf_slot, key, mask=fits)
        tl.store(buf_idx_ptr + buf_slot, pos.to(tl.int32), mask=fits)

    # Every program signals that its buffer writes were issued; only program 0
    # waits for the full count and continues with passes 1-3.
    tl.debug_barrier()
    tl.atomic_add(sync_ptr + 1, 1)
    if block_id == 0:
        while tl.atomic_add(sync_ptr + 1, 0) < NUM_BLOCKS:
            pass

        filled = tl.atomic_add(counter_ptr + 1, 0)

        slot = tl.arange(0, BUF_SIZE)
        used = slot < filled
        c_key = tl.load(buf_val_ptr + slot, mask=used, other=0)
        c_pos = tl.load(buf_idx_ptr + slot, mask=used, other=0)

        # Pass 1: byte 2 over every buffered candidate.
        byte_2 = (c_key >> 16) & 0xFF
        hist_2 = tl.histogram(byte_2, 256, mask=used)
        at_or_above_2 = tl.sum(hist_2) - tl.cumsum(hist_2, axis=0) + hist_2
        pivot_2nd = tl.max(tl.where(at_or_above_2 >= k_left, bin_id, -1))
        k_left = k_left - tl.sum(tl.where(bin_id > pivot_2nd, hist_2, 0))

        # Pass 2: byte 1 over candidates whose upper 16 bits match the prefix.
        prefix16 = (pivot_hi << 8) | pivot_2nd
        same16 = (((c_key >> 16) & 0xFFFF) == prefix16) & used
        byte_1 = (c_key >> 8) & 0xFF
        hist_1 = tl.histogram(byte_1, 256, mask=same16)
        at_or_above_1 = tl.sum(hist_1) - tl.cumsum(hist_1, axis=0) + hist_1
        pivot_3rd = tl.max(tl.where(at_or_above_1 >= k_left, bin_id, -1))
        k_left = k_left - tl.sum(tl.where(bin_id > pivot_3rd, hist_1, 0))

        # Pass 3: byte 0 (LSB) over candidates whose upper 24 bits match.
        prefix24 = (prefix16 << 8) | pivot_3rd
        same24 = (((c_key >> 8) & 0xFFFFFF) == prefix24) & used
        byte_0 = c_key & 0xFF
        hist_0 = tl.histogram(byte_0, 256, mask=same24)
        at_or_above_0 = tl.sum(hist_0) - tl.cumsum(hist_0, axis=0) + hist_0
        pivot_4th = tl.max(tl.where(at_or_above_0 >= k_left, bin_id, -1))
        # k_left now counts the threshold-equal elements still to be emitted.
        k_left = k_left - tl.sum(tl.where(bin_id > pivot_4th, hist_0, 0))

        # Final selection from the buffer.
        cutoff = (prefix24 << 8) | pivot_4th
        n_strictly_above = TOP_K - k_left

        # Sign-bit flip turns the signed '>' into an unsigned comparison.
        c_key_signed = c_key ^ tl.full(c_key.shape, _SIGN_BIT, dtype=tl.int32)
        cutoff_signed = cutoff ^ _SIGN_BIT

        buf_above = (c_key_signed > cutoff_signed) & used
        buf_equal = (c_key == cutoff) & used

        # Buffered elements above the threshold follow the pass-0 output.
        above_slot = n_above_hi + tl.cumsum(buf_above.to(tl.int32), axis=0) - 1
        tl.store(
            indices_ptr + above_slot,
            c_pos,
            mask=buf_above & (above_slot >= 0) & (above_slot < TOP_K),
        )

        # The first k_left threshold-equal elements fill the remaining slots.
        equal_rank = tl.cumsum(buf_equal.to(tl.int32), axis=0) - 1
        equal_slot = n_strictly_above + equal_rank
        tl.store(
            indices_ptr + equal_slot,
            c_pos,
            mask=buf_equal
            & (equal_rank < k_left)
            & (equal_slot >= 0)
            & (equal_slot < TOP_K),
        )

        # Leave barriers and cursors zeroed for the next launch. Program 0
        # only gets here after every program arrived at the second barrier,
        # i.e. after all of them left the first one and used the cursors.
        tl.store(sync_ptr, 0)
        tl.store(sync_ptr + 1, 0)
        tl.store(counter_ptr, 0)
        tl.store(counter_ptr + 1, 0)

472 

473 

474# Persistent scratch buffers, keyed by (device_index, dispatch_tier). 

475# Allocated once per device and reused across calls to avoid cudaMalloc overhead. 

476_cache = {} 

477 

478# Dispatch thresholds for the three kernel tiers 

479_SINGLE_BLOCK_LIMIT = 8192 

480_MEDIUM_BLOCK_LIMIT = 32768 

481_MEDIUM_BLOCK_SIZE = 4096 

482_LARGE_BLOCK_SIZE = 4096 

483_LARGE_BUF_SIZE = 4096 

484 

485 

def top_k_per_row_decode(
    logits, next_n, seq_lens, indices, num_rows, stride0, stride1, top_k
):
    """Top-K per row for decode phase of DeepSeek V4.

    Selects top_k indices from a single row of logits using radix-based
    selection. Only valid elements within [0, seq_lens[0]) are considered.
    The kernel tier is chosen from ``logits.shape[1]``; the two multi-block
    tiers use scratch tensors cached per device in ``_cache``.

    Args:
        logits: [1, vocab_size] float32 tensor.
        next_n: number of next tokens (unused, kept for API compatibility).
        seq_lens: [1] int32 — valid range [0, seq_lens[0]).
        indices: [1, top_k] int32 — output buffer, filled with selected indices.
        num_rows: must be 1 (decode processes one row at a time).
        stride0: logits.stride(0) (unused, kept for API compatibility).
        stride1: logits.stride(1).
        top_k: number of top elements to select.

    Returns:
        None. Results are written into ``indices``.

    Raises:
        AssertionError: if ``num_rows`` is not 1.

    NOTE(review): the cached scratch is shared by every call on a device, so
    overlapping launches on different streams would presumably interfere —
    confirm callers serialize calls per device.
    """
    logger.debug("GEMS TOP_K_PER_ROW_DECODE")

    assert num_rows == 1, "Only num_rows=1 supported in decode path"

    vocab_size = logits.shape[1]
    device = logits.device
    ind = indices.view(-1)

    if vocab_size <= _SINGLE_BLOCK_LIMIT:
        # Single program; use the smaller tile (and fewer warps) when the row
        # fits into half of the single-block limit.
        if vocab_size <= _SINGLE_BLOCK_LIMIT // 2:
            block, warps = _SINGLE_BLOCK_LIMIT // 2, 8
        else:
            block, warps = _SINGLE_BLOCK_LIMIT, 16
        _topk_single_block[(1,)](
            logits,
            seq_lens,
            ind,
            stride1,
            N=vocab_size,
            BLOCK=block,
            TOP_K=top_k,
            num_warps=warps,
        )
        return

    dev_idx = device.index if device.index is not None else 0

    if vocab_size <= _MEDIUM_BLOCK_LIMIT:
        # Medium vocab: double-buffered all-blocks radix
        n_blocks = (vocab_size + _MEDIUM_BLOCK_SIZE - 1) // _MEDIUM_BLOCK_SIZE
        key = (dev_idx, "med")
        if key not in _cache:
            # Sized for the largest grid this tier can launch.
            max_nb = (
                _MEDIUM_BLOCK_LIMIT + _MEDIUM_BLOCK_SIZE - 1
            ) // _MEDIUM_BLOCK_SIZE
            pb_size = max_nb * 256
            pb_hist_a = torch.zeros(pb_size, dtype=torch.int32, device=device)
            pb_hist_b = torch.zeros(pb_size, dtype=torch.int32, device=device)
            sync = torch.zeros(4, dtype=torch.int32, device=device)
            counter = torch.zeros(2, dtype=torch.int32, device=device)
            _cache[key] = (pb_hist_a, pb_hist_b, sync, counter)
        pb_hist_a, pb_hist_b, sync, counter = _cache[key]

        _topk_medium_block[(n_blocks,)](
            logits,
            seq_lens,
            pb_hist_a,
            pb_hist_b,
            sync,
            counter,
            ind,
            stride1,
            N=vocab_size,
            NUM_BLOCKS=n_blocks,
            BLOCK=_MEDIUM_BLOCK_SIZE,
            TOP_K=top_k,
            num_warps=8,
        )
        return

    # Large vocab: buffer-based multi-block radix
    n_blocks = (vocab_size + _LARGE_BLOCK_SIZE - 1) // _LARGE_BLOCK_SIZE
    key = (dev_idx, "large")
    cached = _cache.get(key)
    # The kernel writes 256 histogram entries per program. The scratch used to
    # be fixed at 64 programs, so a larger grid wrote past the histogram area
    # into the adjacent sync/counter slots and beyond the allocation. Size it
    # from the actual grid and grow it when a bigger vocab shows up.
    if cached is None or cached[0].numel() < n_blocks * 256:
        max_nb = max(64, n_blocks)
        pb_size = max_nb * 256
        # Layout: [histograms | 2 sync slots | 2 counter slots]. Zero-filled
        # because the kernel expects sync and counter to start at zero.
        scratch = torch.zeros(pb_size + 4, dtype=torch.int32, device=device)
        buf = torch.empty(_LARGE_BUF_SIZE * 2, dtype=torch.int32, device=device)
        cached = (
            scratch[:pb_size],
            scratch[pb_size : pb_size + 2],
            buf[:_LARGE_BUF_SIZE],
            buf[_LARGE_BUF_SIZE:],
            scratch[pb_size + 2 : pb_size + 4],
        )
        _cache[key] = cached
    pb_hist, sync, buf_val, buf_idx, counter = cached

    _topk_multi_block[(n_blocks,)](
        logits,
        seq_lens,
        pb_hist,
        sync,
        buf_val,
        buf_idx,
        counter,
        ind,
        stride1,
        N=vocab_size,
        NUM_BLOCKS=n_blocks,
        BLOCK=_LARGE_BLOCK_SIZE,
        TOP_K=top_k,
        BUF_SIZE=_LARGE_BUF_SIZE,
        num_warps=8,
    )