Coverage for src/flag_gems/fused/cp_gather_indexer_k_quant_cache.py: 43%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1# Adapted from vLLM v0.20.2:
2# csrc/cache_kernels.cu::cp_gather_indexer_k_quant_cache_kernel
4import torch
5import triton
6import triton.language as tl
# Gather one quantization block of one token's cached key, plus its scale.
#
# Launched on a ``(num_tokens, num_quant_blocks)`` grid: axis 0 selects the
# output token row ``tid``; axis 1 selects which ``QUANT_BLOCK_SIZE``-wide
# slice of the ``HEAD_DIM`` key elements (and which scale) to copy.
#
# * ``cu_seqlen_ptr`` holds ``batch_size + 1`` cumulative token offsets; the
#   sequence ``b`` with ``cu_seqlen[b] <= tid < cu_seqlen[b + 1]`` owns the
#   token. Programs whose token lies in no sequence return without writing.
# * ``block_table_ptr`` maps (sequence, logical block) -> physical cache block,
#   with row stride ``block_table_stride`` (column stride assumed to be 1).
# * Key data for in-block slot ``s`` of physical block ``p`` starts at
#   ``p * kv_cache_stride + s * HEAD_DIM``; its scales start at
#   ``p * kv_cache_scale_stride + s * num_quant_blocks`` (both in elements of
#   the respective pointer).
# * Output scale rows are packed: row ``tid`` starts at
#   ``tid * num_quant_blocks``.
#
# ``batch_size`` is a runtime argument so a new batch size does not trigger a
# recompilation; only its power-of-two bound ``BATCH_BLOCK`` is constexpr.
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,
    kv_cache_scale_ptr,
    k_fp8_ptr,
    k_scale_ptr,
    block_table_ptr,
    cu_seqlen_ptr,
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    k_fp8_stride,
    num_quant_blocks,
    batch_size,
    HEAD_DIM: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    BATCH_BLOCK: tl.constexpr,
):
    # int64 so that ``tid * stride`` style offsets cannot wrap in 32 bits
    # (the reference CUDA kernel computes these offsets as int64_t).
    tid = tl.program_id(0).to(tl.int64)
    quant_block_id = tl.program_id(1)

    # Find the sequence that owns token ``tid`` by testing every
    # [cu_seqlen[b], cu_seqlen[b + 1]) range at once; -1 means "none".
    batch_offsets = tl.arange(0, BATCH_BLOCK)
    batch_mask = batch_offsets < batch_size
    seq_starts = tl.load(cu_seqlen_ptr + batch_offsets, mask=batch_mask, other=0)
    seq_ends = tl.load(cu_seqlen_ptr + batch_offsets + 1, mask=batch_mask, other=0)
    in_batch = (tid >= seq_starts) & (tid < seq_ends) & batch_mask
    batch_id = tl.max(tl.where(in_batch, batch_offsets, -1), axis=0)
    if batch_id < 0:
        return

    # Position of the token inside its sequence -> (logical block, slot).
    batch_start = tl.load(cu_seqlen_ptr + batch_id)
    batch_offset = tid - batch_start
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = batch_id * block_table_stride + block_table_id
    # The block table may be int32; widen before multiplying by the
    # per-block stride, which can exceed 2**31 on large caches.
    block_id = tl.load(block_table_ptr + block_table_offset).to(tl.int64)

    offsets = quant_block_id * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE)
    mask = offsets < HEAD_DIM
    src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
    src_scale_offset = (
        block_id * kv_cache_scale_stride
        + block_offset * num_quant_blocks
        + quant_block_id
    )
    dst_offset = tid * k_fp8_stride

    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset

    # One scale per (token, quant block), then the block's key elements.
    scale_val = tl.load(src_scale_ptr)
    tl.store(k_scale_ptr + tid * num_quant_blocks + quant_block_id, scale_val)
    val = tl.load(src_cache_ptr + offsets, mask=mask)
    tl.store(dst_k_ptr + offsets, val, mask=mask)
def cp_gather_indexer_k_quant_cache(
    k_cache: torch.Tensor,
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
):
    """Gather quantized indexer keys and their scales out of a paged cache.

    Each cache block (``k_cache[i]``) is flattened and read as
    ``block_size * head_dim`` key elements followed by float32 scales. For
    every token ``t`` that lies in some range ``[cu_seqlen[b], cu_seqlen[b+1])``
    this copies the token's key elements into ``k_fp8[t]`` and its
    ``num_quant_blocks`` scales into ``k_fp8_scale[t]``. Output rows of tokens
    covered by no range are left untouched. Outputs are written in place.

    Args:
        k_cache: Paged cache of shape ``(num_blocks, block_size, ...)``; must
            be viewable as ``(num_blocks, -1)``. The layout math counts
            elements as bytes, so a 1-byte dtype (e.g. uint8) is expected.
        k_fp8: Key output with ``num_tokens`` rows of ``head_dim`` 1-byte
            elements (reinterpreted as uint8).
        k_fp8_scale: Scale output with ``num_tokens`` rows; each row must hold
            a whole number of float32 scales (e.g. uint8 of width
            ``4 * num_quant_blocks`` or float32 of width ``num_quant_blocks``).
        block_table: ``(batch_size, max_blocks)`` physical block ids per
            sequence; the column stride is assumed to be 1.
        cu_seqlen: ``batch_size + 1`` cumulative token offsets.

    Raises:
        ValueError: If the scale row width does not describe a positive number
            of float32 scales, if ``head_dim`` is not divisible by the
            resulting quantization block size, or if that size is not a
            positive power of two (required by ``tl.arange``).
    """
    num_tokens = k_fp8.size(0)
    head_dim = k_fp8.shape[-1]
    block_size = k_cache.size(1)
    num_blocks = k_cache.shape[0]
    batch_size = block_table.shape[0]

    # Scales are float32, so derive the block count from the row's byte width;
    # this works whether the caller passed raw bytes or float32 values.
    scale_bytes = k_fp8_scale.size(1) * k_fp8_scale.element_size()
    if scale_bytes == 0 or scale_bytes % 4 != 0:
        raise ValueError(
            "k_fp8_scale rows must hold a positive whole number of float32 scales"
        )
    num_quant_blocks = scale_bytes // 4
    if head_dim % num_quant_blocks != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    quant_block_size = head_dim // num_quant_blocks
    if quant_block_size < 1 or (quant_block_size & (quant_block_size - 1)) != 0:
        raise ValueError("quant_block_size must be a positive power of two")

    if num_tokens == 0 or batch_size == 0:
        # Nothing to gather. Also avoids BATCH_BLOCK == next_power_of_2(0) == 0,
        # which is not a valid tl.arange extent.
        return

    k_cache_flat = k_cache.view(num_blocks, -1)
    k_cache_value = k_cache_flat[:, : block_size * head_dim]
    k_cache_scale = k_cache_flat[:, block_size * head_dim :].view(torch.float32)
    k_fp8_bytes = k_fp8.view(torch.uint8)

    # The kernel writes scale row ``t`` at ``t * num_quant_blocks``. For a
    # non-contiguous buffer, work on a packed copy (contiguous() preserves the
    # rows the kernel skips) and copy the result back afterwards.
    scale_out = k_fp8_scale.view(torch.float32)
    scale_buf = scale_out.contiguous()

    batch_block = triton.next_power_of_2(batch_size)

    _cp_gather_indexer_quant_cache_kernel[(num_tokens, num_quant_blocks)](
        k_cache_value,
        k_cache_scale,
        k_fp8_bytes,
        scale_buf,
        block_table,
        cu_seqlen,
        block_size,
        block_table.stride(0),
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        k_fp8_bytes.stride(0),
        num_quant_blocks,
        batch_size,
        head_dim,
        quant_block_size,
        batch_block,
    )

    if scale_buf is not scale_out:
        scale_out.copy_(scale_buf)