Coverage for src/flag_gems/fused/indexer_k_quant_and_cache.py: 14%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# Adapted from vLLM v0.20.2:
2# csrc/cache_kernels.cu::indexer_k_quant_and_cache_kernel
import functools

import torch
import triton
import triton.language as tl
9def _get_fp8_dtype() -> torch.dtype:
10 try:
11 from vllm.platforms import current_platform
13 return current_platform.fp8_dtype()
14 except ImportError:
15 pass
17 if getattr(torch.version, "hip", None) is not None and hasattr(
18 torch, "float8_e4m3fnuz"
19 ):
20 return torch.float8_e4m3fnuz
21 if hasattr(torch, "float8_e4m3fn"):
22 return torch.float8_e4m3fn
23 raise RuntimeError("float8_e4m3fn is required for indexer_k_quant_and_cache")
26def _is_fp8_fnuz(dtype: torch.dtype) -> bool:
27 return hasattr(torch, "float8_e4m3fnuz") and dtype == torch.float8_e4m3fnuz
# Triton kernel: FP8-quantize key vectors block-wise and scatter the FP8
# values plus one float32 scale per quantization block into a paged cache.
#
# Grid: axis 0 is the token index ``tid``. Axis 1 selects a group of 4
# consecutive quantization blocks within the ``head_dim`` vector.
# Each program therefore handles up to 4 * QUANT_BLOCK_SIZE elements of one token.
@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,  # token t's keys start at k_ptr + t * head_dim (rows must be densely packed)
    kv_cache_ptr,  # value region of the cache; its element type is the FP8 store target
    kv_cache_scale_ptr,  # scale region of the cache (scales computed in float32 here)
    slot_mapping_ptr,  # flat cache slot per token; negative entries are skipped
    kv_cache_scale_stride,  # elements between consecutive cache blocks in the scale view
    kv_cache_value_stride,  # elements between consecutive cache blocks in the value view
    block_size,  # slots per cache block
    num_quant_blocks,  # quantization blocks per token (head_dim // QUANT_BLOCK_SIZE)
    head_dim: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,  # elements per scale; tl.arange needs a power of two
    IS_FNUZ: tl.constexpr,  # FP8 target is e4m3fnuz -> divide amax by 224.0 instead of 448.0
    USE_UE8M0: tl.constexpr,  # round every scale up to a power of two
):
    tid = tl.program_id(0)
    # Index of the first quantization block this program handles (4 per program).
    quant_block_id = tl.program_id(1) * 4
    quant_block_offsets = tl.arange(0, 4)
    head_offsets = tl.arange(0, QUANT_BLOCK_SIZE)
    # [4, QUANT_BLOCK_SIZE] element offsets into the token's head_dim vector;
    # row i covers quantization block quant_block_id + i.
    offsets = (
        quant_block_id + quant_block_offsets[:, None]
    ) * QUANT_BLOCK_SIZE + head_offsets[None, :]
    # Drops rows past the end of head_dim. This only happens in the last
    # program, and only when num_quant_blocks is not a multiple of 4.
    mask = offsets < head_dim

    src_ptr = k_ptr + tid * head_dim
    slot_id = tl.load(slot_mapping_ptr + tid)
    # Negative slots write nothing (presumably padding tokens; confirm with the caller).
    if slot_id < 0:
        return

    # Split the flat slot into (cache block, position inside that block).
    block_id = slot_id // block_size
    block_offset = slot_id % block_size

    val = tl.load(src_ptr + offsets, mask=mask, other=0.0)
    # Per-quantization-block absolute maximum. Masked lanes load 0 and cannot raise it.
    amax = tl.max(tl.abs(val).to(tl.float32), axis=1)
    # scale >= amax / fp8_max, so |val| / scale stays within the FP8 range
    # (up to float rounding). The 1e-4 floor keeps the scale non-zero for
    # all-zero blocks, which avoids a divide-by-zero below.
    # NOTE(review): 448.0 is the largest finite float8_e4m3fn value. The FNUZ
    # path uses 224.0 rather than the FNUZ max of 240.0, presumably mirroring
    # vLLM's ROCm choice; confirm.
    if IS_FNUZ:
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        # Round each scale up to the nearest power of two: 2 ** ceil(log2(scale)).
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale[:, None]).to(kv_cache_ptr.type.element_ty)
    # Value region: cache block base + head_dim elements per slot.
    dst_ptr = kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
    tl.store(dst_ptr + offsets, fp8_val, mask=mask)

    # Scale region: cache block base + num_quant_blocks scales per slot,
    # offset to this program's first quantization block.
    dst_scale_ptr = (
        kv_cache_scale_ptr
        + block_id * kv_cache_scale_stride
        + block_offset * num_quant_blocks
        + quant_block_id
    )
    scale_mask = quant_block_id + quant_block_offsets < num_quant_blocks
    tl.store(dst_scale_ptr + quant_block_offsets, scale, mask=scale_mask)
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt,
) -> None:
    """Quantize ``k`` to FP8 block-wise and scatter it into a paged cache.

    Each token row of ``k`` is split into ``head_dim // quant_block_size``
    quantization blocks. Each block gets one float32 scale, and its values
    are divided by that scale and stored as FP8.

    Cache layout, as addressed here: ``kv_cache`` is viewed as one byte row
    per cache block. The first ``block_size * head_dim`` bytes of a row hold
    the FP8 values (``head_dim`` per slot). The bytes after that are viewed
    as float32 scales (``num_quant_blocks`` per slot).

    Args:
        k: Key tensor whose last dimension is ``head_dim``. Leading
            dimensions are treated as tokens. Non-contiguous inputs are copied
            to a dense buffer, because the kernel addresses token ``t`` at
            element ``t * head_dim``.
        kv_cache: Paged cache shaped ``(num_blocks, block_size, ...)`` with a
            1-byte element type.
        slot_mapping: 1-D integer tensor. Entry ``t`` is the flat cache slot
            (``block_id * block_size + offset``) for token ``t``; negative
            entries are skipped by the kernel.
        quant_block_size: Number of ``head_dim`` elements sharing one scale.
            It must divide ``head_dim``, and ``tl.arange`` requires it to be a
            power of two.
        scale_fmt: ``"ue8m0"`` rounds every scale up to a power of two. Any
            other value keeps the raw float32 scales.

    Raises:
        ValueError: if ``quant_block_size`` is not positive or does not divide
            ``head_dim``, if ``kv_cache`` does not have 1-byte elements, or if
            a cache block row is too small to hold the values plus the scales.
    """
    if quant_block_size <= 0:
        raise ValueError("quant_block_size must be positive")
    head_dim = k.shape[-1]
    if head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size
    if kv_cache.element_size() != 1:
        # The byte-offset slicing below assumes one byte per cache element.
        raise ValueError("kv_cache must have a 1-byte element type (e.g. uint8)")

    num_blocks = kv_cache.shape[0]
    block_size = kv_cache.shape[1]
    value_bytes = block_size * head_dim
    # Each slot needs head_dim FP8 bytes plus num_quant_blocks float32 scales.
    required_row_bytes = block_size * (head_dim + 4 * num_quant_blocks)

    # The kernel reads k row t at t * head_dim and slot_mapping[t] at t, so
    # both must be densely packed. contiguous() is a no-op when they already are.
    k = k.contiguous()
    slot_mapping = slot_mapping.contiguous()

    # Launch at most one program per token that has both a k row and a slot
    # entry. Any excess on either side would be read out of bounds.
    num_k_rows = k.numel() // head_dim if head_dim else 0
    num_tokens = min(slot_mapping.shape[0], num_k_rows)
    if num_tokens == 0 or num_quant_blocks == 0:
        return

    kv_cache_flat = kv_cache.view(num_blocks, -1)
    if kv_cache_flat.shape[1] < required_row_bytes:
        raise ValueError(
            f"kv_cache block rows hold {kv_cache_flat.shape[1]} bytes, but "
            f"{required_row_bytes} are needed for values plus scales"
        )
    fp8_dtype = _get_fp8_dtype()
    kv_cache_value = kv_cache_flat[:, :value_bytes].view(fp8_dtype)
    kv_cache_scale = kv_cache_flat[:, value_bytes:].view(torch.float32)
    # Each kernel program handles 4 quantization blocks of one token.
    grid = (num_tokens, triton.cdiv(num_quant_blocks, 4))
    _indexer_k_quant_and_cache_kernel[grid](
        k,
        kv_cache_value,
        kv_cache_scale,
        slot_mapping,
        kv_cache_scale.stride(0),
        kv_cache_value.stride(0),
        block_size,
        num_quant_blocks,
        head_dim,
        quant_block_size,
        IS_FNUZ=_is_fp8_fnuz(fp8_dtype),
        USE_UE8M0=scale_fmt == "ue8m0",
    )