Coverage for src/flag_gems/fused/cp_gather_indexer_k_quant_cache.py: 46%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# Adapted from vLLM v0.20.2:
2# csrc/cache_kernels.cu::cp_gather_indexer_k_quant_cache_kernel
4import torch
5import triton
6import triton.language as tl
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,
    kv_cache_scale_ptr,
    k_fp8_ptr,
    k_scale_ptr,
    block_table_ptr,
    cu_seqlen_ptr,
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    k_fp8_stride,
    num_quant_blocks,
    batch_size: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
    BATCH_SCAN_SIZE: tl.constexpr,
    SEARCH_STEPS: tl.constexpr,
):
    """Gather one quantization block of cached K values plus its scale.

    Grid axis 0 walks the flattened token index in chunks of ``TOKEN_BLOCK``;
    grid axis 1 selects which quantization block of the head dimension this
    program copies.  For each token the kernel finds the owning sequence via
    ``cu_seqlen_ptr``, resolves the cache block through ``block_table_ptr``,
    and copies ``QUANT_BLOCK_SIZE`` value elements and one scale element from
    the cache to the outputs.  Tokens that fall in no sequence are masked out.

    Args:
        kv_cache_ptr: Base of the cached values.  Cache blocks are
            ``kv_cache_stride`` elements apart and each token slot inside a
            block occupies ``HEAD_DIM`` elements.
        kv_cache_scale_ptr: Base of the cached scales.  Cache blocks are
            ``kv_cache_scale_stride`` elements apart and each token slot
            holds ``num_quant_blocks`` scales.
        k_fp8_ptr: Output values; token rows are ``k_fp8_stride`` apart.
        k_scale_ptr: Output scales, indexed as
            ``token * num_quant_blocks + quant_block``.
        block_table_ptr: Per-sequence list of cache block ids; rows are
            ``block_table_stride`` elements apart.
        cu_seqlen_ptr: ``batch_size + 1`` boundaries; sequence ``b`` owns
            tokens ``[cu_seqlen[b], cu_seqlen[b + 1])``.
        block_size: Number of token slots per cache block.
        num_quant_blocks: Number of scales stored per token.
        batch_size: Number of sequences (compile-time; selects the lookup
            strategy below).
        BATCH_SCAN_SIZE: Width of the linear scan over sequences; only read
            when ``batch_size <= 16``.
        SEARCH_STEPS: Number of bisection iterations; only read when
            ``batch_size > 16``.
    """
    tid = tl.program_id(0) * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
    quant_block_id = tl.program_id(1)

    if batch_size <= 16:
        # Small batch: compare every token against every sequence interval at
        # once and keep the matching sequence index (-1 when none matches).
        batch_offsets = tl.arange(0, BATCH_SCAN_SIZE)
        batch_mask = batch_offsets < batch_size
        seq_starts = tl.load(cu_seqlen_ptr + batch_offsets, mask=batch_mask, other=0)
        seq_ends = tl.load(cu_seqlen_ptr + batch_offsets + 1, mask=batch_mask, other=0)
        in_batch = (
            (tid[:, None] >= seq_starts[None, :])
            & (tid[:, None] < seq_ends[None, :])
            & batch_mask[None, :]
        )
        batch_id = tl.max(tl.where(in_batch, batch_offsets[None, :], -1), axis=1)
    else:
        # Large batch: bisect over cu_seqlen[0 .. batch_size] for the first
        # boundary strictly greater than the token index; the owning sequence
        # is the one just before it.
        # NOTE(review): presumes cu_seqlen is non-decreasing - confirm.
        left = tl.full((TOKEN_BLOCK,), 0, dtype=tl.int32)
        right = tl.full((TOKEN_BLOCK,), batch_size + 1, dtype=tl.int32)
        for _ in tl.static_range(0, SEARCH_STEPS):
            mid = (left + right) // 2
            # Out-of-range probes read as INT32_MAX so they compare greater
            # than any token index and push the search back to the left.
            seq_start = tl.load(
                cu_seqlen_ptr + mid,
                mask=mid <= batch_size,
                other=2147483647,
            )
            seq_start_before_token = seq_start <= tid
            left = tl.where(seq_start_before_token, mid + 1, left)
            right = tl.where(seq_start_before_token, right, mid)
        batch_id = left - 1

    # Clamp the sequence index so address arithmetic stays in range even for
    # lanes that are masked out.
    valid_batch = (batch_id >= 0) & (batch_id < batch_size)
    safe_batch_id = tl.minimum(tl.maximum(batch_id, 0), batch_size - 1)
    batch_start = tl.load(cu_seqlen_ptr + safe_batch_id, mask=valid_batch, other=0)
    batch_end = tl.load(cu_seqlen_ptr + safe_batch_id + 1, mask=valid_batch, other=0)
    valid_tokens = valid_batch & (tid >= batch_start) & (tid < batch_end)

    # Position of the token inside its sequence -> (block-table slot, slot
    # inside the cache block).
    batch_offset = tid - batch_start
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = safe_batch_id * block_table_stride + block_table_id
    block_id = tl.load(block_table_ptr + block_table_offset, mask=valid_tokens, other=0)

    offsets = quant_block_id * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE)
    mask = valid_tokens[:, None]

    # Widen before multiplying by the per-block strides: block_id * stride can
    # exceed the int32 range for large caches.  Both the value offset and the
    # scale offset use the widened operands.
    block_id_i64 = block_id.to(tl.int64)
    block_offset_i64 = block_offset.to(tl.int64)
    src_cache_offset = (
        block_id_i64[:, None] * kv_cache_stride
        + block_offset_i64[:, None] * HEAD_DIM
    )
    src_scale_offset = (
        block_id_i64 * kv_cache_scale_stride
        + block_offset_i64 * num_quant_blocks
        + quant_block_id
    )
    dst_offset = tid[:, None].to(tl.int64) * k_fp8_stride

    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset

    # One scale per (token, quant block).
    scale_val = tl.load(
        src_scale_ptr,
        mask=valid_tokens,
        other=0.0,
    )
    tl.store(
        k_scale_ptr + tid * num_quant_blocks + quant_block_id,
        scale_val,
        mask=valid_tokens,
    )
    # QUANT_BLOCK_SIZE value elements per (token, quant block); the per-token
    # mask broadcasts across the element axis.
    val = tl.load(src_cache_ptr + offsets[None, :], mask=mask)
    tl.store(dst_k_ptr + offsets[None, :], val, mask=mask)
def cp_gather_indexer_k_quant_cache(
    k_cache: torch.Tensor,
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
) -> None:
    """Gather quantized K values and their scales from a paged cache.

    The results are written in place into ``k_fp8`` and ``k_fp8_scale``
    (through dtype views that share their storage); nothing is returned.

    Args:
        k_cache: Paged cache.  Dim 0 is the cache block and dim 1 the token
            slot inside a block.  Flattened per block, the first
            ``block_size * head_dim`` elements are treated as values and the
            remainder is reinterpreted as float32 scales.
            NOTE(review): this layout presumes a 1-byte element type
            (e.g. uint8) - confirm against the caller.
        k_fp8: Output values, one row per token; the last dim is
            ``head_dim``.  Reinterpreted as uint8 for the copy.
        k_fp8_scale: Output scales, one row per token.  Reinterpreted as
            float32; its second dim is used as the scale byte count per
            token when deriving the quantization block size.
        block_table: One row of cache block ids per sequence; its first dim
            defines the batch size.
        cu_seqlen: Sequence boundaries read by the kernel as
            ``batch_size + 1`` entries.

    Raises:
        ValueError: If ``head_dim`` is not a multiple of the derived
            quantization block size.
    """
    num_tokens = k_fp8.size(0)
    block_size = k_cache.size(1)
    block_table_stride = block_table.stride(0)
    head_dim = k_fp8.shape[-1]
    num_blocks = k_cache.shape[0]

    # Each scale is 4 bytes, so size(1) / 4 scales cover head_dim elements.
    quant_block_size = head_dim * 4 // k_fp8_scale.size(1)
    if head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size

    # Nothing to gather: skip view construction and kernel
    # specialisation/launch on an empty grid.
    if num_tokens == 0:
        return

    # Split every cache block into its value region and its scale region.
    k_cache_flat = k_cache.view(num_blocks, -1)
    value_elems = block_size * head_dim
    k_cache_value = k_cache_flat[:, :value_elems]
    k_cache_scale = k_cache_flat[:, value_elems:].view(torch.float32)
    k_fp8 = k_fp8.view(torch.uint8)
    k_fp8_scale = k_fp8_scale.view(torch.float32)
    batch_size = block_table.shape[0]

    # Tokens handled per program grow with the workload, capped at 32.
    token_block = 32
    for token_limit, tokens_per_program in (
        (32, 1),
        (64, 2),
        (128, 4),
        (256, 8),
        (512, 16),
    ):
        if num_tokens < token_limit:
            token_block = tokens_per_program
            break

    if batch_size <= 16:
        batch_scan_size = triton.next_power_of_2(batch_size)
    else:
        # Unused by the binary-search path; kept as a valid constexpr placeholder.
        batch_scan_size = 1
    # Enough bisection steps to narrow batch_size + 1 candidates to one.
    search_steps = batch_size.bit_length()

    grid = (triton.cdiv(num_tokens, token_block), num_quant_blocks)
    _cp_gather_indexer_quant_cache_kernel[grid](
        k_cache_value,
        k_cache_scale,
        k_fp8,
        k_fp8_scale,
        block_table,
        cu_seqlen,
        block_size,
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        k_fp8.stride(0),
        num_quant_blocks,
        batch_size=batch_size,
        HEAD_DIM=head_dim,
        QUANT_BLOCK_SIZE=quant_block_size,
        TOKEN_BLOCK=token_block,
        BATCH_SCAN_SIZE=batch_scan_size,
        SEARCH_STEPS=search_steps,
    )