Coverage for src/flag_gems/fused/indexer_k_quant_and_cache.py: 15%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

# Adapted from vLLM v0.20.2:
# csrc/cache_kernels.cu::indexer_k_quant_and_cache_kernel

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

9def _get_fp8_dtype() -> torch.dtype: 

10 try: 

11 from vllm.platforms import current_platform 

12 

13 return current_platform.fp8_dtype() 

14 except ImportError: 

15 pass 

16 

17 if getattr(torch.version, "hip", None) is not None and hasattr( 

18 torch, "float8_e4m3fnuz" 

19 ): 

20 return torch.float8_e4m3fnuz 

21 if hasattr(torch, "float8_e4m3fn"): 

22 return torch.float8_e4m3fn 

23 raise RuntimeError("float8_e4m3fn is required for indexer_k_quant_and_cache") 

24 

25 

def _is_fp8_fnuz(dtype: torch.dtype) -> bool:
    """Return True when *dtype* is ``torch.float8_e4m3fnuz``.

    The ``hasattr`` guard makes this return False (rather than raise
    ``AttributeError``) on torch builds that do not define the fnuz dtype.
    """
    return hasattr(torch, "float8_e4m3fnuz") and dtype == torch.float8_e4m3fnuz

28 

29 

@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,
    kv_cache_ptr,
    kv_cache_scale_ptr,
    slot_mapping_ptr,
    kv_cache_scale_stride,
    kv_cache_value_stride,
    block_size,
    num_quant_blocks,
    head_dim: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    IS_FNUZ: tl.constexpr,
    USE_UE8M0: tl.constexpr,
):
    """Quantize one quant-block of one token's K row and write it to the cache.

    Each program handles the pair (``program_id(0)``, ``program_id(1)``) =
    (token index, quant-block index). It loads ``QUANT_BLOCK_SIZE`` elements
    of that token's row, derives a single scale from their absolute maximum,
    stores the scaled values through ``kv_cache_ptr`` and the scale through
    ``kv_cache_scale_ptr``.

    Args:
        k_ptr: base pointer of the K rows; row ``t`` is read at
            ``k_ptr + t * head_dim`` (rows must be densely packed).
        kv_cache_ptr: base pointer of the value region; stored values are
            cast to this pointer's element type.
        kv_cache_scale_ptr: base pointer of the scale region.
        slot_mapping_ptr: per-token destination slot; a negative slot makes
            the program exit without writing anything.
        kv_cache_scale_stride: distance between consecutive cache blocks in
            the scale region, in elements of ``kv_cache_scale_ptr``.
        kv_cache_value_stride: distance between consecutive cache blocks in
            the value region, in elements of ``kv_cache_ptr``.
        block_size: number of slots per cache block.
        num_quant_blocks: number of quant blocks (and scales) per token.
        head_dim: row length of K (compile-time constant).
        QUANT_BLOCK_SIZE: elements sharing one scale (compile-time constant,
            used as the ``tl.arange`` extent).
        IS_FNUZ: selects the 224.0 divisor instead of 448.0.
        USE_UE8M0: when set, rounds the scale up to a power of two.
    """
    # Widen to int64 so `tid * head_dim` cannot wrap for very large inputs.
    tid = tl.program_id(0).to(tl.int64)
    quant_block_id = tl.program_id(1)
    offsets = quant_block_id * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE)
    mask = offsets < head_dim

    src_ptr = k_ptr + tid * head_dim
    # Widen the slot as well: the block offsets derived from it are multiplied
    # by the per-block strides below and must not wrap if the mapping is
    # stored in a narrower integer type.
    slot_id = tl.load(slot_mapping_ptr + tid).to(tl.int64)
    if slot_id < 0:
        # Negative slot: nothing is written for this token.
        return

    block_id = slot_id // block_size
    block_offset = slot_id % block_size

    val = tl.load(src_ptr + offsets, mask=mask, other=0.0)
    amax = tl.max(tl.abs(val).to(tl.float32), axis=0)
    # The 1e-4 floor keeps the scale strictly positive for all-zero input.
    # NOTE(review): 224.0 / 448.0 are presumably tied to the representable
    # range of the two FP8 e4m3 variants - confirm against the upstream
    # vLLM kernel.
    if IS_FNUZ:
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        # Round the scale up to the next power of two.
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
    dst_ptr = kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
    tl.store(dst_ptr + offsets, fp8_val, mask=mask)

    # One scale per (slot, quant block), laid out slot-major inside the block.
    dst_scale_ptr = (
        kv_cache_scale_ptr
        + block_id * kv_cache_scale_stride
        + block_offset * num_quant_blocks
        + quant_block_id
    )
    tl.store(dst_scale_ptr, scale)

79 

80 

def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt,
) -> None:
    """Quantize ``k`` per quant block and scatter it into ``kv_cache`` in place.

    ``kv_cache`` is viewed as ``(num_blocks, -1)``. Within each row, the first
    ``block_size * head_dim`` elements are reinterpreted as the FP8 dtype and
    receive the quantized values; the remaining elements are reinterpreted as
    ``float32`` and receive one scale per (slot, quant block).

    Args:
        k: source rows; ``k.shape[-1]`` is taken as ``head_dim``. The kernel
            reads row ``t`` at offset ``t * head_dim``, so ``k`` is made
            contiguous before launch.
        kv_cache: destination cache; ``shape[0]`` is the number of blocks and
            ``shape[1]`` the number of slots per block.
            NOTE(review): the dtype reinterpretation assumes a 1-byte element
            type (e.g. uint8) - confirm against callers.
        slot_mapping: destination slot for each token; its length determines
            how many tokens are processed. Negative slots are skipped by the
            kernel.
        quant_block_size: number of consecutive elements sharing one scale.
            Must be positive and divide ``head_dim``. It is also used as the
            ``tl.arange`` extent inside the kernel.
        scale_fmt: the string ``"ue8m0"`` enables power-of-two scales; any
            other value leaves the scale unrounded.

    Raises:
        ValueError: if ``quant_block_size`` is not positive or does not
            divide ``head_dim``.
    """
    if quant_block_size <= 0:
        # Reject this up front: zero would otherwise surface as a bare
        # ZeroDivisionError from the modulo below, and a negative value would
        # produce a negative launch grid.
        raise ValueError("quant_block_size must be positive")

    num_blocks = kv_cache.shape[0]
    head_dim = k.shape[-1]
    num_tokens = slot_mapping.shape[0]
    block_size = kv_cache.shape[1]
    if head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size

    # The kernel addresses both tensors with unit/packed strides; these calls
    # are no-ops when the inputs are already contiguous.
    k = k.contiguous()
    slot_mapping = slot_mapping.contiguous()

    kv_cache_flat = kv_cache.view(num_blocks, -1)
    fp8_dtype = _get_fp8_dtype()
    value_elems = block_size * head_dim
    kv_cache_value = kv_cache_flat[:, :value_elems].view(fp8_dtype)
    kv_cache_scale = kv_cache_flat[:, value_elems:].view(torch.float32)

    # One program per (token, quant block).
    grid = (num_tokens, num_quant_blocks)
    _indexer_k_quant_and_cache_kernel[grid](
        k,
        kv_cache_value,
        kv_cache_scale,
        slot_mapping,
        kv_cache_scale.stride(0),
        kv_cache_value.stride(0),
        block_size,
        num_quant_blocks,
        head_dim,
        quant_block_size,
        IS_FNUZ=_is_fp8_fnuz(fp8_dtype),
        USE_UE8M0=scale_fmt == "ue8m0",
    )