Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/concat_and_cache_mla.py: 0%
79 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import libentry

# Child logger under the package-wide "flag_gems" logger.
# NOTE(review): lstrip(".") removes leading dots only — a no-op for normal
# module names; presumably defensive against relative-name forms.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# enum Fp8KVCacheDataType — passed to the kernel as a tl.constexpr so the
# dtype branches are resolved at compile time.
FP8_KV_CACHE_DATA_TYPE_AUTO = tl.constexpr(0)
FP8_KV_CACHE_DATA_TYPE_FP8E4M3 = tl.constexpr(1)
FP8_KV_CACHE_DATA_TYPE_FP8E5M2 = tl.constexpr(2)
@libentry()
@triton.jit
def concat_and_cache_mla_kernel(
    # pointers
    kv_c_ptr,  # in, [num_tokens, kv_lora_rank]
    k_pe_ptr,  # in, [num_tokens, pe_dim]
    kv_cache_ptr,  # out, [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping_ptr,  # in, [num_tokens]
    # strides
    block_stride,
    entry_stride,
    kv_c_stride,
    k_pe_stride,
    # dims
    kv_lora_rank,
    pe_dim,
    block_size,  # kv cache block size
    scale_ptr,
    # data type
    kv_dtype: tl.constexpr,  # one of Fp8KVCacheDataType
    BLOCK_SIZE: tl.constexpr,
):
    # One program per token: copy this token's kv_c row and k_pe row into the
    # cache entry named by slot_mapping[token_idx], concatenated along the
    # last dim (kv_c first, then k_pe).  When kv_dtype selects an FP8 format,
    # values are scaled, converted, and stored bit-cast as uint8.
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx)

    # Skip padded tokens (negative slots mark padding)
    if slot_idx < 0:
        return

    # Calculate cache position: slot -> (block, offset within block)
    block_id = slot_idx // block_size
    block_offset = slot_idx % block_size
    cache_base = block_id * block_stride + block_offset * entry_stride

    # Preload scale if needed — only read on the FP8 paths; the constexpr
    # test prunes this load entirely in "auto" mode.
    if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
        scale_val = tl.load(scale_ptr)

    # Process kv_c section in BLOCK_SIZE chunks; mask covers the ragged tail.
    for i in range(0, kv_lora_rank, BLOCK_SIZE):
        idx = i + tl.arange(0, BLOCK_SIZE)
        mask = idx < kv_lora_rank

        src_ptr = kv_c_ptr + token_idx * kv_c_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            # Reinterpret the FP8 bits as uint8 so they can be stored into
            # the uint8-typed cache without a value conversion.
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

    # Process k_pe section — written after the kv_c span of the same entry.
    for j in range(0, pe_dim, BLOCK_SIZE):
        idx = j + tl.arange(0, BLOCK_SIZE)
        mask = idx < pe_dim

        src_ptr = k_pe_ptr + token_idx * k_pe_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + kv_lora_rank + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)
class ConcatAndCacheMla(torch.autograd.Function):
    """Scatter per-token kv_c and k_pe rows into an MLA kv cache.

    Forward-only op (no ``backward`` is defined).  ``kv_cache`` is updated
    in place by the Triton kernel and returned.
    """

    @staticmethod
    def forward(
        ctx,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        scale: torch.Tensor,
    ):
        """Write each token's (kv_c ‖ k_pe) row into its mapped cache slot.

        Args:
            kv_c: [num_tokens, kv_lora_rank] compressed kv input.
            k_pe: [num_tokens, pe_dim] positional-encoding input.
            kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim] cache,
                written in place.
            slot_mapping: [num_tokens] flat slot index per token; negative
                entries are skipped (padding).
            kv_cache_dtype: "auto", "fp8", "fp8e4m3", or "fp8e5m2".
            scale: single-element tensor used as the FP8 quantization scale.

        Returns:
            The (mutated) kv_cache tensor.

        Raises:
            ValueError: on dtype mismatches, unsupported kv_cache_dtype, or a
                non-scalar ``scale``.
        """
        if kv_cache_dtype != "auto" and kv_cache.dtype != torch.uint8:
            raise ValueError("For FP8 kv_cache must be uint8 dtype")
        if kv_cache_dtype == "auto" and kv_cache.dtype != kv_c.dtype:
            raise ValueError("For auto mode kv_cache must match input dtype")

        # Map string dtype to internal constants
        kv_dtype_map = {
            "auto": FP8_KV_CACHE_DATA_TYPE_AUTO,
            "fp8": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e4m3": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e5m2": FP8_KV_CACHE_DATA_TYPE_FP8E5M2,
        }
        kv_dtype = kv_dtype_map.get(kv_cache_dtype)
        if kv_dtype is None:
            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
        kv_dtype = int(kv_dtype)  # tl.constexpr -> int

        kv_lora_rank = kv_c.size(1)
        pe_dim = k_pe.size(1)
        num_tokens = slot_mapping.size(0)

        # The kernel reads exactly one scale element.  The original check
        # (`if scale.numel() != 1: scale = scale.view(1)`) was inverted:
        # view(1) itself raises for any multi-element tensor.  Validate
        # explicitly and normalize a 0-dim scalar to shape [1].
        if scale.numel() != 1:
            raise ValueError("scale must be a single-element tensor")
        scale = scale.reshape(1)

        # make sure all tensors are on the same device
        device = kv_c.device
        k_pe = k_pe.to(device)
        kv_cache = kv_cache.to(device)
        slot_mapping = slot_mapping.to(device)
        scale = scale.to(device)

        # configure kernel launch: one program per token
        grid = (num_tokens,)
        # tl.arange requires a power-of-two extent, so round kv_lora_rank up
        # before capping at 512; the in-kernel mask handles any ragged tail.
        BLOCK_SIZE = min(triton.next_power_of_2(kv_lora_rank), 512)

        assert kv_cache.dim() == 3, "kv_cache must be a 3D tensor"
        assert (
            kv_cache.size(2) == kv_lora_rank + pe_dim
        ), "kv_cache's last dimension must match kv_lora_rank + pe_dim"
        with torch.cuda.device(device):
            concat_and_cache_mla_kernel[grid](
                kv_c,
                k_pe,
                kv_cache,
                slot_mapping,
                kv_cache.stride(0),  # block_stride
                kv_cache.stride(1),  # entry_stride
                kv_c.stride(0),  # kv_c_stride
                k_pe.stride(0),  # k_pe_stride
                kv_lora_rank,
                pe_dim,
                kv_cache.size(1),  # kv cache block_size
                scale,
                kv_dtype=kv_dtype,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        return kv_cache
def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Public entry point: write (kv_c ‖ k_pe) rows into the MLA kv cache.

    Thin wrapper over :class:`ConcatAndCacheMla`.  The original annotation
    said ``-> None`` but the function returns the (in-place updated)
    kv_cache tensor from ``apply``; the annotation is corrected here.
    """
    logger.debug("GEMS CONCAT_AND_CACHE_MLA")
    return ConcatAndCacheMla.apply(
        kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
    )