Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/concat_and_cache_mla.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

# Module-scoped logger, nested under the shared "flag_gems" logger.
# lstrip(".") guards against a leading dot in a relative __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# enum Fp8KVCacheDataType
# Integer codes passed into the kernel as a tl.constexpr to select the
# KV-cache storage format.
FP8_KV_CACHE_DATA_TYPE_AUTO = tl.constexpr(0)  # store values in their original dtype
FP8_KV_CACHE_DATA_TYPE_FP8E4M3 = tl.constexpr(1)  # quantize to float8 e4m3
FP8_KV_CACHE_DATA_TYPE_FP8E5M2 = tl.constexpr(2)  # quantize to float8 e5m2

15 

16 

@libentry()
@triton.jit
def concat_and_cache_mla_kernel(
    # pointers
    kv_c_ptr,  # in, [num_tokens, kv_lora_rank]
    k_pe_ptr,  # in, [num_tokens, pe_dim]
    kv_cache_ptr,  # out, [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping_ptr,  # in, [num_tokens]
    # strides
    block_stride,
    entry_stride,
    kv_c_stride,
    k_pe_stride,
    # dims
    kv_lora_rank,
    pe_dim,
    block_size,  # kv cache block size
    scale_ptr,
    # data type
    kv_dtype: tl.constexpr,  # one of Fp8KVCacheDataType
    BLOCK_SIZE: tl.constexpr,
):
    """Write one token's concatenated MLA KV entry into the paged KV cache.

    One program instance per token (grid = (num_tokens,)). The token's
    kv_c row (kv_lora_rank values) is stored first, followed by its k_pe
    row (pe_dim values), at the cache slot given by slot_mapping[token].
    When kv_dtype is not AUTO, values are divided by *scale_ptr and
    quantized to the selected FP8 format, then bitcast to uint8 for
    storage in a uint8 cache.
    """
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx)

    # Skip padded tokens (a negative slot marks padding).
    if slot_idx < 0:
        return

    # Calculate cache position: which cache block, and which entry in it.
    block_id = slot_idx // block_size
    block_offset = slot_idx % block_size
    cache_base = block_id * block_stride + block_offset * entry_stride

    # Preload scale if needed (kv_dtype is constexpr, so this branch is
    # resolved at compile time).
    if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
        scale_val = tl.load(scale_ptr)

    # Process kv_c section: copy [0, kv_lora_rank) in BLOCK_SIZE chunks.
    for i in range(0, kv_lora_rank, BLOCK_SIZE):
        idx = i + tl.arange(0, BLOCK_SIZE)
        mask = idx < kv_lora_rank

        src_ptr = kv_c_ptr + token_idx * kv_c_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            # Reinterpret FP8 bits as uint8 so they can land in a uint8 cache.
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

    # Process k_pe section: copy into [kv_lora_rank, kv_lora_rank + pe_dim).
    for j in range(0, pe_dim, BLOCK_SIZE):
        idx = j + tl.arange(0, BLOCK_SIZE)
        mask = idx < pe_dim

        src_ptr = k_pe_ptr + token_idx * k_pe_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + kv_lora_rank + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            # Same bitcast trick as the kv_c section above.
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

90 

91 

class ConcatAndCacheMla(torch.autograd.Function):
    """Launches the concat_and_cache_mla Triton kernel.

    Forward-only op (writes the paged KV cache in place); no backward
    is defined.
    """

    @staticmethod
    def forward(
        ctx,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        scale: torch.Tensor,
    ):
        """Concatenate kv_c and k_pe rows into kv_cache at the given slots.

        Args:
            kv_c: [num_tokens, kv_lora_rank] compressed KV input.
            k_pe: [num_tokens, pe_dim] positional-encoding input.
            kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim]
                paged cache, written in place.
            slot_mapping: [num_tokens] flat slot index per token
                (negative entries mark padding and are skipped).
            kv_cache_dtype: "auto", "fp8", "fp8e4m3", or "fp8e5m2".
            scale: 1-element dequantization scale tensor (used only for
                the FP8 modes).

        Returns:
            The (mutated) kv_cache tensor.

        Raises:
            ValueError: on dtype mismatches, an unsupported
                kv_cache_dtype string, or a non-scalar scale.
        """
        if kv_cache_dtype != "auto" and kv_cache.dtype != torch.uint8:
            raise ValueError("For FP8 kv_cache must be uint8 dtype")
        if kv_cache_dtype == "auto" and kv_cache.dtype != kv_c.dtype:
            raise ValueError("For auto mode kv_cache must match input dtype")

        # Map string dtype to internal constants
        kv_dtype_map = {
            "auto": FP8_KV_CACHE_DATA_TYPE_AUTO,
            "fp8": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e4m3": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e5m2": FP8_KV_CACHE_DATA_TYPE_FP8E5M2,
        }
        kv_dtype = kv_dtype_map.get(kv_cache_dtype)
        if kv_dtype is None:
            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
        kv_dtype = int(kv_dtype)  # tl.constexpr->int

        kv_lora_rank = kv_c.size(1)
        pe_dim = k_pe.size(1)
        num_tokens = slot_mapping.size(0)

        # The kernel reads exactly one scalar from scale; reject anything
        # else up front.  (The previous `scale.view(1)` on a multi-element
        # tensor could never succeed -- it always raised an opaque
        # RuntimeError -- so fail with a clear message instead.)
        if scale.numel() != 1:
            raise ValueError("scale must be a 1-element (scalar) tensor")

        # make sure all tensors are on the same device
        device = kv_c.device
        k_pe = k_pe.to(device)
        kv_cache = kv_cache.to(device)
        slot_mapping = slot_mapping.to(device)
        scale = scale.to(device)

        # configure kernel launch: one program per token.
        grid = (num_tokens,)
        # tl.arange requires a power-of-2 extent, so round the chunk size
        # up to the next power of two (masks handle the overrun).
        chunk = min(kv_lora_rank, 512)
        BLOCK_SIZE = 1 << max(chunk - 1, 0).bit_length()

        assert kv_cache.dim() == 3, "kv_cache must be a 3D tensor"
        assert (
            kv_cache.size(2) == kv_lora_rank + pe_dim
        ), "kv_cache's last dimension must match kv_lora_rank + pe_dim"
        with torch.cuda.device(device):
            concat_and_cache_mla_kernel[grid](
                kv_c,
                k_pe,
                kv_cache,
                slot_mapping,
                kv_cache.stride(0),  # block_stride
                kv_cache.stride(1),  # entry_stride
                kv_c.stride(0),  # kv_c_stride
                k_pe.stride(0),  # k_pe_stride
                kv_lora_rank,
                pe_dim,
                kv_cache.size(1),  # kv cache block_size
                scale,
                kv_dtype=kv_dtype,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        return kv_cache

161 

162 

def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Public entry point: scatter kv_c/k_pe rows into the paged KV cache.

    Thin wrapper around ConcatAndCacheMla.apply; see that class for the
    argument contract.  Returns the mutated kv_cache tensor (the original
    annotation said ``None``, contradicting the actual ``return``).
    """
    logger.debug("GEMS CONCAT_AND_CACHE_MLA")
    return ConcatAndCacheMla.apply(
        kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
    )