Coverage for src/flag_gems/fused/concat_and_cache_mla.py: 51%
80 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
# Module logger, keyed to this module's import path.
logger = logging.getLogger(__name__)

# Kernel-side encoding of the Fp8KVCacheDataType enum (see comment above).
# Passed to the Triton kernel as tl.constexpr values so the non-matching
# branches are specialized away at compile time.
FP8_KV_CACHE_DATA_TYPE_AUTO = tl.constexpr(0)
FP8_KV_CACHE_DATA_TYPE_FP8E4M3 = tl.constexpr(1)
FP8_KV_CACHE_DATA_TYPE_FP8E5M2 = tl.constexpr(2)
@libentry()
@triton.jit
def concat_and_cache_mla_kernel(
    # pointers
    kv_c_ptr,  # in, [num_tokens, kv_lora_rank]
    k_pe_ptr,  # in, [num_tokens, pe_dim]
    kv_cache_ptr,  # out, [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping_ptr,  # in, [num_tokens]; negative entries mark padded tokens
    # strides
    block_stride,  # kv_cache stride over dim 0 (blocks)
    entry_stride,  # kv_cache stride over dim 1 (entries within a block)
    kv_c_stride,  # kv_c stride over dim 0 (tokens)
    k_pe_stride,  # k_pe stride over dim 0 (tokens)
    # dims
    kv_lora_rank,
    pe_dim,
    block_size,  # kv cache block size
    scale_ptr,  # single-element tensor; FP8 scale (only read when kv_dtype != auto)
    # data type
    kv_dtype: tl.constexpr,  # one of Fp8KVCacheDataType
    BLOCK_SIZE: tl.constexpr,  # chunk width; must be a power of two for tl.arange
):
    """Copy one token's kv_c and k_pe rows into its slot of the paged KV cache,
    optionally quantizing values to FP8 (e4m3 or e5m2) with a shared scale.

    One program instance handles one token (launched with grid=(num_tokens,)).
    NOTE(review): the pointer arithmetic below assumes the last dimension of
    every tensor has stride 1 (contiguous) -- confirm at the call site.
    """
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx)
    # Skip padded tokens (marked with a negative slot index)
    if slot_idx < 0:
        return

    # Calculate cache position: which cache block, and which entry inside it
    block_id = slot_idx // block_size
    block_offset = slot_idx % block_size
    cache_base = block_id * block_stride + block_offset * entry_stride

    # Preload scale if needed; kv_dtype is a constexpr, so for "auto" this
    # branch (and all FP8 code below) is compiled out entirely.
    if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
        scale_val = tl.load(scale_ptr)

    # Process kv_c section in masked BLOCK_SIZE-wide chunks
    for i in range(0, kv_lora_rank, BLOCK_SIZE):
        idx = i + tl.arange(0, BLOCK_SIZE)
        mask = idx < kv_lora_rank

        src_ptr = kv_c_ptr + token_idx * kv_c_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            # Reinterpret the FP8 bits as uint8 to match the uint8 cache buffer
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

    # Process k_pe section; it is stored right after the kv_c section
    # (offset kv_lora_rank) within the same cache entry
    for j in range(0, pe_dim, BLOCK_SIZE):
        idx = j + tl.arange(0, BLOCK_SIZE)
        mask = idx < pe_dim

        src_ptr = k_pe_ptr + token_idx * k_pe_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + kv_lora_rank + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)
class ConcatAndCacheMla(torch.autograd.Function):
    """Forward-only op: scatter (kv_c, k_pe) token rows into a paged KV cache.

    Mutates ``kv_cache`` in place and returns ``None``; no backward is
    defined.
    """

    @staticmethod
    def forward(
        ctx,
        kv_c: torch.Tensor,  # [num_tokens, kv_lora_rank]
        k_pe: torch.Tensor,  # [num_tokens, pe_dim]
        kv_cache: torch.Tensor,  # [num_blocks, block_size, kv_lora_rank + pe_dim]
        slot_mapping: torch.Tensor,  # [num_tokens]; negative = padded token
        kv_cache_dtype: str,  # "auto" | "fp8" | "fp8e4m3" | "fp8e5m2"
        scale: torch.Tensor,  # single-element FP8 quantization scale
    ):
        """Validate inputs and launch the concat_and_cache_mla kernel.

        Raises:
            ValueError: on dtype mismatches, an unsupported
                ``kv_cache_dtype`` string, or a non-scalar ``scale``.
        """
        if kv_cache_dtype != "auto" and kv_cache.dtype != torch.uint8:
            raise ValueError("For FP8 kv_cache must be uint8 dtype")
        if kv_cache_dtype == "auto" and kv_cache.dtype != kv_c.dtype:
            raise ValueError("For auto mode kv_cache must match input dtype")

        # Map string dtype to internal constants
        kv_dtype_map = {
            "auto": FP8_KV_CACHE_DATA_TYPE_AUTO,
            "fp8": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e4m3": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e5m2": FP8_KV_CACHE_DATA_TYPE_FP8E5M2,
        }
        kv_dtype = kv_dtype_map.get(kv_cache_dtype)
        if kv_dtype is None:
            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
        kv_dtype = int(kv_dtype)  # tl.constexpr -> plain int for the launch

        kv_lora_rank = kv_c.size(1)
        pe_dim = k_pe.size(1)
        num_tokens = slot_mapping.size(0)

        # Make sure `scale` is a single-element 1-D tensor.
        # BUGFIX: the old code did `if scale.numel() != 1: scale = scale.view(1)`,
        # which called view(1) exactly when it must fail (numel != 1) and
        # skipped the 0-d -> 1-d reshape for genuine scalars.
        if scale.numel() != 1:
            raise ValueError("scale must be a single-element tensor")
        if scale.dim() == 0:
            scale = scale.view(1)

        # Make sure all tensors are on the same device as kv_c
        # (no-op copies when already there).
        device = kv_c.device
        k_pe = k_pe.to(device)
        kv_cache = kv_cache.to(device)
        slot_mapping = slot_mapping.to(device)
        scale = scale.to(device)

        # Configure kernel launch: one program per token.
        grid = (num_tokens,)
        # BUGFIX: tl.arange requires a power-of-two extent, but kv_lora_rank
        # need not be one; round up before capping at 512.
        BLOCK_SIZE = min(triton.next_power_of_2(kv_lora_rank), 512)

        assert kv_cache.dim() == 3, "kv_cache must be a 3D tensor"
        assert (
            kv_cache.size(2) == kv_lora_rank + pe_dim
        ), "kv_cache's last dimension must match kv_lora_rank + pe_dim"
        with torch_device_fn.device(device):
            concat_and_cache_mla_kernel[grid](
                kv_c,
                k_pe,
                kv_cache,
                slot_mapping,
                kv_cache.stride(0),  # block_stride
                kv_cache.stride(1),  # entry_stride
                kv_c.stride(0),  # kv_c_stride
                k_pe.stride(0),  # k_pe_stride
                kv_lora_rank,
                pe_dim,
                kv_cache.size(1),  # kv cache block_size
                scale,
                kv_dtype=kv_dtype,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        return None
def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Write each token's (kv_c, k_pe) pair into its paged KV-cache slot.

    Thin public wrapper around :class:`ConcatAndCacheMla`; ``kv_cache`` is
    updated in place and the call returns ``None``.
    """
    logger.debug("GEMS CONCAT_AND_CACHE_MLA")
    op_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    return ConcatAndCacheMla.apply(*op_args)