Coverage for src/flag_gems/fused/concat_and_cache_mla.py: 51%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

logger = logging.getLogger(__name__)

# enum Fp8KVCacheDataType
# Integer tags for the cache storage format, passed into the kernel as a
# tl.constexpr so the FP8 branches are pruned at compile time.
# NOTE(review): presumably mirrors an external Fp8KVCacheDataType enum
# (e.g. vLLM's) — confirm the numeric values against that definition.
FP8_KV_CACHE_DATA_TYPE_AUTO = tl.constexpr(0)      # store as-is, no quantization
FP8_KV_CACHE_DATA_TYPE_FP8E4M3 = tl.constexpr(1)   # quantize to float8 e4m3
FP8_KV_CACHE_DATA_TYPE_FP8E5M2 = tl.constexpr(2)   # quantize to float8 e5m2

16 

17 

@libentry()
@triton.jit
def concat_and_cache_mla_kernel(
    # pointers
    kv_c_ptr,  # in, [num_tokens, kv_lora_rank]
    k_pe_ptr,  # in, [num_tokens, pe_dim]
    kv_cache_ptr,  # out, [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping_ptr,  # in, [num_tokens]
    # strides
    block_stride,  # kv_cache stride over dim 0 (blocks)
    entry_stride,  # kv_cache stride over dim 1 (entries within a block)
    kv_c_stride,  # kv_c stride over dim 0 (tokens)
    k_pe_stride,  # k_pe stride over dim 0 (tokens)
    # dims
    kv_lora_rank,
    pe_dim,
    block_size,  # kv cache block size
    scale_ptr,  # single-element tensor holding the FP8 quantization divisor
    # data type
    kv_dtype: tl.constexpr,  # one of Fp8KVCacheDataType
    BLOCK_SIZE: tl.constexpr,  # elements copied per inner-loop step
):
    """Copy one token's [kv_c | k_pe] row into its paged-KV-cache slot.

    Grid: one program per token on axis 0. The token's flat slot index is
    read from slot_mapping; a negative index marks a padded token, which is
    skipped. When kv_dtype is not AUTO, values are divided by *scale and
    cast to the selected float8 format, then bit-cast to uint8 for storage
    (the cache is expected to be a uint8 buffer in that case).
    """
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx)

    # Skip padded tokens (negative slot index = no cache slot assigned)
    if slot_idx < 0:
        return

    # Calculate cache position: flat slot -> (block, offset within block)
    block_id = slot_idx // block_size
    block_offset = slot_idx % block_size
    cache_base = block_id * block_stride + block_offset * entry_stride

    # Preload scale if needed. kv_dtype is a constexpr, so this branch is
    # resolved at compile time and scale_val is only referenced in
    # specializations where it is actually defined.
    if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
        scale_val = tl.load(scale_ptr)

    # Process kv_c section: the first kv_lora_rank elements of the entry,
    # copied BLOCK_SIZE elements at a time with a tail mask.
    for i in range(0, kv_lora_rank, BLOCK_SIZE):
        idx = i + tl.arange(0, BLOCK_SIZE)
        mask = idx < kv_lora_rank

        src_ptr = kv_c_ptr + token_idx * kv_c_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            # Reinterpret the fp8 bits as uint8 so they can be stored in the
            # uint8-typed cache buffer without a value-changing conversion.
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

    # Process k_pe section: written immediately after the kv_c section,
    # i.e. at offset kv_lora_rank within the same cache entry.
    for j in range(0, pe_dim, BLOCK_SIZE):
        idx = j + tl.arange(0, BLOCK_SIZE)
        mask = idx < pe_dim

        src_ptr = k_pe_ptr + token_idx * k_pe_stride + idx
        dst_ptr = kv_cache_ptr + cache_base + kv_lora_rank + idx

        val = tl.load(src_ptr, mask=mask, other=0)

        if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO:
            if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3:
                val = (val / scale_val).to(tl.float8e4nv)
            elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2:
                val = (val / scale_val).to(tl.float8e5)
            val = val.to(tl.uint8, bitcast=True)
        tl.store(dst_ptr, val, mask=mask)

91 

92 

class ConcatAndCacheMla(torch.autograd.Function):
    """Scatter per-token MLA KV data into a paged KV cache.

    For each token, the compressed KV vector (``kv_c``) and the positional
    encoding vector (``k_pe``) are concatenated and written into the cache
    slot named by ``slot_mapping``, optionally quantized to FP8.
    Forward-only: no backward is defined.
    """

    @staticmethod
    def forward(
        ctx,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        scale: torch.Tensor,
    ):
        """Validate inputs and launch the concat-and-cache kernel.

        Args:
            kv_c: [num_tokens, kv_lora_rank] compressed KV input.
            k_pe: [num_tokens, pe_dim] positional-encoding input.
            kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim] paged
                cache, mutated in place.
            slot_mapping: [num_tokens] flat slot index per token; negative
                entries mark padded tokens and are skipped by the kernel.
            kv_cache_dtype: one of "auto", "fp8", "fp8e4m3", "fp8e5m2".
            scale: single-element tensor; FP8 quantization divisor.

        Returns:
            None; the result is the in-place update of ``kv_cache``.

        Raises:
            ValueError: on dtype mismatch, unsupported ``kv_cache_dtype``,
                non-scalar ``scale``, or ``kv_cache`` on a different device
                than ``kv_c``.
        """
        if kv_cache_dtype != "auto" and kv_cache.dtype != torch.uint8:
            raise ValueError("For FP8 kv_cache must be uint8 dtype")
        if kv_cache_dtype == "auto" and kv_cache.dtype != kv_c.dtype:
            raise ValueError("For auto mode kv_cache must match input dtype")

        # Map string dtype to internal constants
        kv_dtype_map = {
            "auto": FP8_KV_CACHE_DATA_TYPE_AUTO,
            "fp8": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e4m3": FP8_KV_CACHE_DATA_TYPE_FP8E4M3,
            "fp8e5m2": FP8_KV_CACHE_DATA_TYPE_FP8E5M2,
        }
        kv_dtype = kv_dtype_map.get(kv_cache_dtype)
        if kv_dtype is None:
            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
        kv_dtype = int(kv_dtype)  # tl.constexpr->int

        kv_lora_rank = kv_c.size(1)
        pe_dim = k_pe.size(1)
        num_tokens = slot_mapping.size(0)

        # The kernel reads exactly one value via tl.load(scale_ptr), so
        # `scale` must hold a single element. BUGFIX: the previous code did
        # `scale = scale.view(1)` under this condition, which can never
        # succeed (view to shape (1,) requires numel() == 1) and raised an
        # opaque RuntimeError; fail with a clear, consistent error instead.
        if scale.numel() != 1:
            raise ValueError(
                f"scale must be a single-element tensor, got numel()={scale.numel()}"
            )

        # Make sure every tensor the kernel reads lives on kv_c's device.
        device = kv_c.device
        k_pe = k_pe.to(device)
        slot_mapping = slot_mapping.to(device)
        scale = scale.to(device)
        # BUGFIX: kv_cache is written in place; `.to(device)` on a
        # mismatched device would silently redirect the kernel's writes into
        # a temporary copy and drop the results, so reject that case.
        if kv_cache.device != device:
            raise ValueError("kv_cache must be on the same device as kv_c")

        # Configure kernel launch: one program per token.
        # BUGFIX: tl.arange requires a power-of-two extent, so round the
        # per-step vector width up before capping it at 512 (a plain
        # min(kv_lora_rank, 512) fails to compile for non-power-of-2 ranks).
        grid = (num_tokens,)
        BLOCK_SIZE = min(triton.next_power_of_2(kv_lora_rank), 512)

        assert kv_cache.dim() == 3, "kv_cache must be a 3D tensor"
        assert (
            kv_cache.size(2) == kv_lora_rank + pe_dim
        ), "kv_cache's last dimension must match kv_lora_rank + pe_dim"
        with torch_device_fn.device(device):
            concat_and_cache_mla_kernel[grid](
                kv_c,
                k_pe,
                kv_cache,
                slot_mapping,
                kv_cache.stride(0),  # block_stride
                kv_cache.stride(1),  # entry_stride
                kv_c.stride(0),  # kv_c_stride
                k_pe.stride(0),  # k_pe_stride
                kv_lora_rank,
                pe_dim,
                kv_cache.size(1),  # kv cache block_size
                scale,
                kv_dtype=kv_dtype,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        return None

162 

163 

def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Write each token's concatenated [kv_c | k_pe] row into ``kv_cache``.

    Thin public entry point that logs the dispatch and delegates all
    validation and the kernel launch to :class:`ConcatAndCacheMla`.
    ``kv_cache`` is updated in place; nothing is returned.
    """
    logger.debug("GEMS CONCAT_AND_CACHE_MLA")
    return ConcatAndCacheMla.apply(
        kv_c,
        k_pe,
        kv_cache,
        slot_mapping,
        kv_cache_dtype,
        scale,
    )