Coverage for src/flag_gems/fused/top_k_per_row_prefill.py: 44%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

"""Triton top_k_per_row_prefill for DeepSeek V4 sparse attention.

Replaces vLLM's persistent_topk CUDA kernel with a Triton implementation.

Background:
    In DeepSeek V4 prefill, each token computes attention logits over a subset of
    the vocabulary [row_starts[i], row_ends[i]) and selects the top-K indices.
    Typical config: vocab_size=129280, top_k=1024, num_rows=1 (decode) or 32+ (prefill).

Strategy:
    1. In-place masking kernel: set logits outside [row_starts, row_ends) to -inf.
       Rows that cover the full vocab (start == 0, end >= vocab_size) exit early.
       Nothing in such a row needs masking (its store would be fully masked off
       anyway), so the early exit skips the offset/mask computation for this
       common case rather than saving memory writes.
    2. Adaptive top-K selection:
       - num_rows=1: torch.argsort (a full sort, reported to be backed by a CUB
         radix sort on CUDA; observed ~2x faster than torch.topk for a single
         large-vocab row).
       - num_rows>1: torch.topk with sorted=False (selection-based -- radix
         select on CUDA -- so rows are not fully sorted; sorted=False also skips
         ordering the selected K elements).
    3. Fused postprocess kernel: a single Triton kernel performs slice + cast +
       subtract in one pass, converting absolute vocab indices to 0-based indices
       relative to row_starts[i]. Saves one kernel launch vs. separate
       slice/subtract ops.

Performance (DeepSeek V4 config, vocab=129280, top_k=1024):
    - num_rows=1: 0.89x vs vLLM CUDA (competitive, bounded by argsort)
    - num_rows=32: 0.38x vs vLLM CUDA (bounded by torch.topk on large vocab)
"""

27 

28import torch 

29import triton 

30import triton.language as tl 

31 

32 

@triton.jit
def _mask_invalid_kernel(
    logits_ptr,
    row_starts_ptr,
    row_ends_ptr,
    stride0,  # logits row stride (= vocab_size for contiguous tensor)
    BLOCK_SIZE: tl.constexpr,  # 8192: tuned for 129280 vocab (16 blocks/row)
    VOCAB_SIZE: tl.constexpr,  # total vocabulary size (e.g. 129280)
):
    """Mask logits outside [row_starts[i], row_ends[i]) to -inf, in-place.

    Grid: (num_rows * num_blocks_per_row,) -- 1D flat grid.
    Each program handles one BLOCK_SIZE chunk of one row.

    Addressing assumes a unit column stride: element (row, col) lives at
    ``logits_ptr + row * stride0 + col``.
    """
    pid = tl.program_id(0)
    num_blocks_per_row = tl.cdiv(VOCAB_SIZE, BLOCK_SIZE)
    row_id = pid // num_blocks_per_row
    block_id = pid % num_blocks_per_row

    start = tl.load(row_starts_ptr + row_id)
    end = tl.load(row_ends_ptr + row_id)

    # A row covering the whole vocab has nothing to mask: the store below would
    # be fully masked off anyway, so bail out before computing offsets/masks.
    if start == 0 and end >= VOCAB_SIZE:
        return

    # Positions of this block that fall outside the valid range [start, end).
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    out_of_range = (offs < start) | (offs >= end)
    # Only write positions that are both inside the vocab AND outside the range.
    mask = (offs < VOCAB_SIZE) & out_of_range

    # Promote the row offset to int64: program_id and small integer arguments are
    # 32-bit, so row_id * stride0 would wrap once num_rows * stride0 >= 2**31
    # (about 16.6k rows at stride 129280).
    row_ptr = logits_ptr + row_id.to(tl.int64) * stride0
    tl.store(row_ptr + offs, float("-inf"), mask=mask)

68 

69 

@triton.jit
def _fused_postprocess_kernel(
    src_ptr,  # source indices (from argsort or topk)
    dst_ptr,  # destination: output indices buffer, one row of top_k per input row
    row_starts_ptr,  # per-row start of the valid range (inclusive)
    row_ends_ptr,  # per-row end of the valid range (exclusive)
    num_rows,  # runtime arg: as a constexpr every new batch size would recompile
    src_stride0,  # row stride of src (vocab_size for argsort, top_k for topk)
    dst_stride0,  # row stride of dst (top_k for a contiguous buffer)
    top_k: tl.constexpr,  # 1024 in DeepSeek V4
    BLOCK_SIZE: tl.constexpr,  # next_power_of_2(top_k), e.g. 1024
):
    """Fused slice + cast + subtract: convert absolute indices to row-relative.

    For each row i and each j < top_k, with idx = src[i, j]:
        dst[i, j] = idx - row_starts[i]   if row_starts[i] <= idx < row_ends[i]
        dst[i, j] = -1                    otherwise
    Every position outside [row_starts[i], row_ends[i]) was set to -inf by the
    mask kernel; such positions are selected only when the row has fewer than
    top_k valid entries (or valid entries tie with them at -inf), and they are
    reported as -1 instead of a meaningless offset.
    Grid: (num_rows,) -- one program per row.
    """
    row_id = tl.program_id(0)
    if row_id >= num_rows:
        return

    row_start = tl.load(row_starts_ptr + row_id)
    row_end = tl.load(row_ends_ptr + row_id)

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < top_k

    # int64 row offsets keep the address arithmetic from wrapping for large inputs.
    row64 = row_id.to(tl.int64)
    src_vals = tl.load(src_ptr + row64 * src_stride0 + offs, mask=mask, other=0)

    in_range = (src_vals >= row_start) & (src_vals < row_end)
    dst_vals = tl.where(in_range, src_vals - row_start, -1).to(tl.int32)

    tl.store(dst_ptr + row64 * dst_stride0 + offs, dst_vals, mask=mask)


def top_k_per_row_prefill(
    logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k
):
    """Top-K per row for prefill phase of DeepSeek V4 sparse attention.

    Masks invalid ranges in-place, then selects top-K indices per row.
    Output indices are 0-based relative to row_starts[i].

    Args:
        logits: [num_rows, vocab_size] float32 tensor; rows [0, num_rows) are
            modified in-place (positions outside the valid range set to -inf).
            Any extra rows beyond num_rows are neither masked nor searched.
            In DeepSeek V4: vocab_size=129280.
        row_starts: [num_rows] int32 -- start of valid range per row (inclusive).
        row_ends: [num_rows] int32 -- end of valid range per row (exclusive).
        indices: [num_rows, top_k] int32 -- output buffer pre-allocated by the
            caller (unit column stride; row stride read from indices.stride(0)).
            Row i receives indices in [0, row_ends[i] - row_starts[i]). Slots
            that cannot hold an in-range index (the row has fewer than top_k
            valid positions) are set to -1 (presumably the same padding value
            vLLM's CUDA kernel uses -- TODO confirm with callers).
        num_rows: number of rows (1 for decode, 32/64/2048 for prefill batches).
        stride0: logits.stride(0), typically == vocab_size for contiguous tensor.
        stride1: logits.stride(1), typically == 1. The Triton mask kernel needs
            a unit column stride; other values fall back to a torch masked_fill_.
        top_k: number of top elements per row (1024 in DeepSeek V4).

    Raises:
        ValueError: if top_k exceeds vocab_size.
    """
    vocab_size = logits.shape[1]

    if top_k > vocab_size:
        raise ValueError(f"top_k ({top_k}) must not exceed vocab_size ({vocab_size})")
    if num_rows == 0:
        return

    # Only the first num_rows rows are masked and post-processed, so restrict
    # the sort / top-k to them (a no-op view when logits has exactly num_rows).
    active = logits[:num_rows]

    # --- Phase 1: Mask invalid ranges to -inf ---
    if stride1 == 1:
        # BLOCK_SIZE=8192 chosen to balance occupancy vs. grid size:
        # For vocab=129280, this gives ceil(129280/8192)=16 blocks per row.
        # num_warps=2 is sufficient since masking is memory-bound (simple store).
        MASK_BS = 8192
        num_mask_blocks = (vocab_size + MASK_BS - 1) // MASK_BS
        _mask_invalid_kernel[(num_rows * num_mask_blocks,)](
            logits,
            row_starts,
            row_ends,
            stride0,
            BLOCK_SIZE=MASK_BS,
            VOCAB_SIZE=vocab_size,
            num_warps=2,
        )
    else:
        # The Triton mask kernel addresses columns with unit stride; use an
        # equivalent (slower) torch path for other layouts.
        cols = torch.arange(vocab_size, device=logits.device)
        invalid = (cols < row_starts[:num_rows, None]) | (
            cols >= row_ends[:num_rows, None]
        )
        active.masked_fill_(invalid, float("-inf"))

    # --- Phase 2: Select top-K indices ---
    if num_rows == 1:
        # Single row: a full torch.argsort was observed to be ~2x faster than
        # torch.topk for one large-vocab row (CUB radix sort under the hood).
        src_idx = torch.argsort(active, dim=1, descending=True, stable=False)
    else:
        # Multi-row: torch.topk is selection-based (radix select on CUDA), so it
        # avoids fully sorting every row; sorted=False skips ordering the K
        # selected elements, which the caller does not need.
        _, src_idx = torch.topk(active, top_k, dim=1, largest=True, sorted=False)

    # --- Phase 3: Fused slice + cast + subtract (+ -1 for out-of-range) ---
    # POSTPROC_BLOCK must be power-of-2 >= top_k for tl.arange.
    # For top_k=1024, this is exactly 1024 (no waste).
    POSTPROC_BLOCK = triton.next_power_of_2(top_k)
    _fused_postprocess_kernel[(num_rows,)](
        src_idx,
        indices,
        row_starts,
        row_ends,
        num_rows=num_rows,
        src_stride0=src_idx.stride(0),
        dst_stride0=indices.stride(0),
        top_k=top_k,
        BLOCK_SIZE=POSTPROC_BLOCK,
        num_warps=4,
    )