Coverage for src/flag_gems/fused/cp_gather_indexer_k_quant_cache.py: 46%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

# Adapted from vLLM v0.20.2:
# csrc/cache_kernels.cu::cp_gather_indexer_k_quant_cache_kernel

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

# Gather one quant block (QUANT_BLOCK_SIZE bytes) of the keys of TOKEN_BLOCK
# consecutive output tokens, plus that quant block's scale for each token.
#
# Program ids: axis 0 tiles the output tokens in chunks of TOKEN_BLOCK; axis 1
# selects the quant block within the head (cp_gather_indexer_k_quant_cache
# launches num_quant_blocks programs along it).
# Each token id is first mapped to its sequence via cu_seqlen (a linear scan
# when batch_size <= 16, a binary search otherwise), then to a physical cache
# block via block_table.
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,
    kv_cache_scale_ptr,
    k_fp8_ptr,
    k_scale_ptr,
    block_table_ptr,
    cu_seqlen_ptr,
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    k_fp8_stride,
    num_quant_blocks,
    batch_size: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
    BATCH_SCAN_SIZE: tl.constexpr,
    SEARCH_STEPS: tl.constexpr,
):
    tid = tl.program_id(0) * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
    quant_block_id = tl.program_id(1)

    # --- Map every token id to its sequence (batch) index. ---
    if batch_size <= 16:
        # Small batches: test every token against every [start, end) range at
        # once. Tokens outside all ranges end up with batch_id == -1.
        batch_offsets = tl.arange(0, BATCH_SCAN_SIZE)
        batch_mask = batch_offsets < batch_size
        seq_starts = tl.load(cu_seqlen_ptr + batch_offsets, mask=batch_mask, other=0)
        seq_ends = tl.load(cu_seqlen_ptr + batch_offsets + 1, mask=batch_mask, other=0)
        in_batch = (
            (tid[:, None] >= seq_starts[None, :])
            & (tid[:, None] < seq_ends[None, :])
            & batch_mask[None, :]
        )
        batch_id = tl.max(tl.where(in_batch, batch_offsets[None, :], -1), axis=1)
    else:
        # Large batches: upper-bound binary search over cu_seqlen[0..batch_size]
        # for the first entry greater than tid. Probes past batch_size read as
        # INT32_MAX, so they always shrink `right`. This relies on cu_seqlen
        # being non-decreasing and SEARCH_STEPS being large enough to converge.
        left = tl.full((TOKEN_BLOCK,), 0, dtype=tl.int32)
        right = tl.full((TOKEN_BLOCK,), batch_size + 1, dtype=tl.int32)
        for _ in tl.static_range(0, SEARCH_STEPS):
            mid = (left + right) // 2
            seq_start = tl.load(
                cu_seqlen_ptr + mid,
                mask=mid <= batch_size,
                other=2147483647,
            )
            seq_start_before_token = seq_start <= tid
            left = tl.where(seq_start_before_token, mid + 1, left)
            right = tl.where(seq_start_before_token, right, mid)
        batch_id = left - 1

    # --- Validate the mapping and locate the token inside its sequence. ---
    valid_batch = (batch_id >= 0) & (batch_id < batch_size)
    # Clamp so the loads below stay in bounds even for invalid tokens.
    safe_batch_id = tl.minimum(tl.maximum(batch_id, 0), batch_size - 1)
    batch_start = tl.load(cu_seqlen_ptr + safe_batch_id, mask=valid_batch, other=0)
    batch_end = tl.load(cu_seqlen_ptr + safe_batch_id + 1, mask=valid_batch, other=0)
    valid_tokens = valid_batch & (tid >= batch_start) & (tid < batch_end)
    batch_offset = tid - batch_start
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = safe_batch_id * block_table_stride + block_table_id
    block_id = tl.load(block_table_ptr + block_table_offset, mask=valid_tokens, other=0)

    # --- Addresses. Cache offsets are widened to int64 because
    # block_id * per-block stride can exceed the int32 range on large caches;
    # the scale offset needs this just as much as the value offset. ---
    offsets = quant_block_id * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE)
    mask = valid_tokens[:, None]
    src_cache_offset = (
        block_id[:, None].to(tl.int64) * kv_cache_stride
        + block_offset[:, None].to(tl.int64) * HEAD_DIM
    )
    # Within a cache block the scales are laid out as [block_size, num_quant_blocks].
    src_scale_offset = (
        block_id.to(tl.int64) * kv_cache_scale_stride
        + block_offset.to(tl.int64) * num_quant_blocks
        + quant_block_id
    )
    dst_offset = tid[:, None].to(tl.int64) * k_fp8_stride

    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset

    # Tokens not covered by any sequence are masked out of every load/store.
    scale_val = tl.load(
        src_scale_ptr,
        mask=valid_tokens,
        other=0.0,
    )
    # NOTE: output scale rows are addressed with a row stride of
    # num_quant_blocks, i.e. k_scale_ptr is assumed densely packed.
    tl.store(
        k_scale_ptr + tid * num_quant_blocks + quant_block_id,
        scale_val,
        mask=valid_tokens,
    )
    val = tl.load(src_cache_ptr + offsets[None, :], mask=mask)
    tl.store(dst_k_ptr + offsets[None, :], val, mask=mask)

98 

99 

def _launch_config(num_tokens, batch_size):
    """Choose the kernel's compile-time launch parameters.

    Args:
        num_tokens: number of output rows to gather.
        batch_size: number of sequences; expected to be >= 1.

    Returns:
        ``(token_block, batch_scan_size, search_steps)``: tokens handled per
        program, width of the linear batch scan (used when
        ``batch_size <= 16``), and the number of binary-search iterations
        (used when ``batch_size > 16``).
    """
    # Heuristic: give each program more tokens as the total grows.
    if num_tokens < 32:
        token_block = 1
    elif num_tokens < 64:
        token_block = 2
    elif num_tokens < 128:
        token_block = 4
    elif num_tokens < 256:
        token_block = 8
    elif num_tokens < 512:
        token_block = 16
    else:
        token_block = 32

    if batch_size <= 16:
        batch_scan_size = triton.next_power_of_2(batch_size)
    else:
        # Unused by the binary-search path; kept as a valid constexpr placeholder.
        batch_scan_size = 1

    # The kernel's binary search runs over the half-open range
    # [0, batch_size + 1) of cu_seqlen entries. Each step can at worst shrink
    # a range of length n to n // 2, so it needs (batch_size + 1).bit_length()
    # steps to converge. batch_size.bit_length() is one step short whenever
    # batch_size + 1 is a power of two (31, 63, 127, ...): tokens of sequence 0
    # then end with batch_id == -1 and are never gathered. Extra steps after
    # convergence leave left/right unchanged for a non-decreasing cu_seqlen.
    search_steps = (batch_size + 1).bit_length()
    return token_block, batch_scan_size, search_steps


def cp_gather_indexer_k_quant_cache(
    k_cache: torch.Tensor,
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
):
    """Gather quantized keys and their scales from a paged cache, in place.

    Every token ``t`` belonging to sequence ``b`` (i.e.
    ``cu_seqlen[b] <= t < cu_seqlen[b + 1]``) has its ``head_dim`` key bytes
    copied from its cache slot into ``k_fp8[t]``, and its per-quant-block
    float32 scales into ``k_fp8_scale[t]``. Nothing is returned.

    Args:
        k_cache: paged cache indexed as ``[num_blocks, block_size, ...]``.
            Flattened per block, the first ``block_size * head_dim`` elements
            are the key bytes and the remainder is reinterpreted as float32
            scales laid out ``[block_size, num_quant_blocks]``. Assumed to be a
            byte (uint8) tensor — TODO confirm with callers.
        k_fp8: output of shape ``(num_tokens, head_dim)``; reinterpreted as uint8.
        k_fp8_scale: output whose ``size(1)`` is counted in bytes
            (``num_quant_blocks * 4``); reinterpreted as float32. Rows are
            assumed densely packed, since the kernel uses ``num_quant_blocks``
            as the row stride.
        block_table: ``(batch_size, max_blocks_per_seq)`` physical block ids.
        cu_seqlen: ``batch_size + 1`` cumulative sequence lengths; must be
            non-decreasing (the binary-search path relies on it).

    Raises:
        ValueError: if ``head_dim`` is not a multiple of the quant block size
            derived from ``k_fp8_scale``.
    """
    num_tokens = k_fp8.size(0)
    block_size = k_cache.size(1)
    block_table_stride = block_table.stride(0)
    head_dim = k_fp8.shape[-1]
    num_blocks = k_cache.shape[0]
    # k_fp8_scale.size(1) is in bytes: 4 bytes (one float32) per quant block.
    quant_block_size = head_dim * 4 // k_fp8_scale.size(1)
    if head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size

    # Split every cache block into its key bytes and its trailing float32 scales.
    k_cache_flat = k_cache.view(num_blocks, -1)
    k_cache_value = k_cache_flat[:, : block_size * head_dim]
    k_cache_scale = k_cache_flat[:, block_size * head_dim :].view(torch.float32)
    k_fp8 = k_fp8.view(torch.uint8)
    k_fp8_scale = k_fp8_scale.view(torch.float32)
    batch_size = block_table.shape[0]

    # Nothing to gather. Returning here also avoids compiling the kernel for an
    # empty batch, where next_power_of_2(0) == 0 is not a valid tl.arange size.
    if num_tokens == 0 or batch_size == 0:
        return

    token_block, batch_scan_size, search_steps = _launch_config(num_tokens, batch_size)

    grid = (triton.cdiv(num_tokens, token_block), num_quant_blocks)
    _cp_gather_indexer_quant_cache_kernel[grid](
        k_cache_value,
        k_cache_scale,
        k_fp8,
        k_fp8_scale,
        block_table,
        cu_seqlen,
        block_size,
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        k_fp8.stride(0),
        num_quant_blocks,
        batch_size,
        head_dim,
        quant_block_size,
        token_block,
        batch_scan_size,
        search_steps,
    )