Coverage for src/flag_gems/fused/indexer_k_quant_and_cache.py: 14%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

# Adapted from vLLM v0.20.2:
# csrc/cache_kernels.cu::indexer_k_quant_and_cache_kernel

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

9def _get_fp8_dtype() -> torch.dtype: 

10 try: 

11 from vllm.platforms import current_platform 

12 

13 return current_platform.fp8_dtype() 

14 except ImportError: 

15 pass 

16 

17 if getattr(torch.version, "hip", None) is not None and hasattr( 

18 torch, "float8_e4m3fnuz" 

19 ): 

20 return torch.float8_e4m3fnuz 

21 if hasattr(torch, "float8_e4m3fn"): 

22 return torch.float8_e4m3fn 

23 raise RuntimeError("float8_e4m3fn is required for indexer_k_quant_and_cache") 

24 

25 

def _is_fp8_fnuz(dtype: torch.dtype) -> bool:
    """Return True only when *dtype* is ``torch.float8_e4m3fnuz``.

    Always False on torch builds that do not define that dtype.
    """
    fnuz_dtype = getattr(torch, "float8_e4m3fnuz", None)
    if fnuz_dtype is None:
        return False
    return dtype == fnuz_dtype

28 

29 

@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,
    kv_cache_ptr,
    kv_cache_scale_ptr,
    slot_mapping_ptr,
    kv_cache_scale_stride,
    kv_cache_value_stride,
    block_size,
    num_quant_blocks,
    head_dim: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    IS_FNUZ: tl.constexpr,
    USE_UE8M0: tl.constexpr,
):
    """Quantize one token's K row per quantization group and scatter it into the cache.

    Grid: axis 0 indexes the token, axis 1 indexes a bundle of up to four
    consecutive quantization groups of ``QUANT_BLOCK_SIZE`` elements each.

    For every group the program computes ``scale = max(1e-4, amax) / D`` with
    ``D = 224`` when ``IS_FNUZ`` else ``448``; when ``USE_UE8M0`` is set the
    scale is rounded up to a power of two.  The values divided by the scale
    are stored (cast to the cache's element type) at
    ``block_id * kv_cache_value_stride + block_offset * head_dim`` and the
    scales at ``block_id * kv_cache_scale_stride + block_offset * num_quant_blocks``.

    Tokens whose slot id is negative are skipped.  The source row of token
    ``t`` is read from ``k_ptr + t * head_dim``, i.e. rows are expected to be
    densely packed.
    """
    # Promote to 64-bit so the row and cache offsets below cannot wrap around
    # for large inputs or large caches, whatever the slot_mapping dtype is.
    token_idx = tl.program_id(0).to(tl.int64)
    first_group = tl.program_id(1) * 4
    group_lanes = tl.arange(0, 4)
    elem_lanes = tl.arange(0, QUANT_BLOCK_SIZE)

    # (4, QUANT_BLOCK_SIZE) element offsets inside the head; groups past the
    # end of the head are masked out.
    offsets = (first_group + group_lanes[:, None]) * QUANT_BLOCK_SIZE + elem_lanes[
        None, :
    ]
    mask = offsets < head_dim

    slot_id = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_id < 0:
        # Negative slot: nothing to write for this token.
        return

    block_id = slot_id // block_size
    block_offset = slot_id % block_size

    src_ptr = k_ptr + token_idx * head_dim
    val = tl.load(src_ptr + offsets, mask=mask, other=0.0)

    # Per-group absolute maximum, clamped away from zero before scaling.
    amax = tl.max(tl.abs(val).to(tl.float32), axis=1)
    if IS_FNUZ:
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        # Round the scale up to the next power of two.
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale[:, None]).to(kv_cache_ptr.type.element_ty)
    dst_ptr = kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
    tl.store(dst_ptr + offsets, fp8_val, mask=mask)

    dst_scale_ptr = (
        kv_cache_scale_ptr
        + block_id * kv_cache_scale_stride
        + block_offset * num_quant_blocks
        + first_group
    )
    # Only groups that actually exist get a scale written.
    scale_mask = first_group + group_lanes < num_quant_blocks
    tl.store(dst_scale_ptr + group_lanes, scale, mask=scale_mask)

84 

85 

def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size,
    scale_fmt,
):
    """Quantize ``k`` per group to FP8 and write values and scales into ``kv_cache``.

    ``kv_cache`` is flattened to ``(num_blocks, -1)``.  The first
    ``block_size * head_dim`` entries of each block are reinterpreted as the
    FP8 value region and the remainder as the float32 scale region.
    NOTE(review): this presumes ``kv_cache`` has a 1-byte element type and
    that each block holds ``block_size * (head_dim + 4 * num_quant_blocks)``
    bytes -- confirm against the allocator.

    Args:
        k: source tensor; its last dimension is ``head_dim`` and token ``t``
            is read as row ``t`` of a densely packed buffer.
        kv_cache: destination cache, written in place; ``shape[0]`` is the
            number of blocks and ``shape[1]`` the block size.
        slot_mapping: one slot id per token; negative ids are skipped by the
            kernel.  Its length determines how many tokens are processed.
        quant_block_size: number of elements sharing one scale.
        scale_fmt: when equal to ``"ue8m0"`` scales are rounded up to a power
            of two; any other value leaves them unrounded.

    Raises:
        ValueError: if ``head_dim`` is not a multiple of ``quant_block_size``.
    """
    num_blocks = kv_cache.shape[0]
    block_size = kv_cache.shape[1]
    head_dim = k.shape[-1]
    num_tokens = slot_mapping.shape[0]
    if head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size

    # The kernel addresses token t at k_ptr + t * head_dim and slot t at
    # slot_mapping_ptr + t, so strided views would be read from the wrong
    # memory.  contiguous() is a no-op for already-contiguous inputs.
    if not k.is_contiguous():
        k = k.contiguous()
    if not slot_mapping.is_contiguous():
        slot_mapping = slot_mapping.contiguous()

    fp8_dtype = _get_fp8_dtype()
    value_elems = block_size * head_dim
    kv_cache_flat = kv_cache.view(num_blocks, -1)
    kv_cache_value = kv_cache_flat[:, :value_elems].view(fp8_dtype)
    kv_cache_scale = kv_cache_flat[:, value_elems:].view(torch.float32)

    # One program per (token, bundle of four quantization groups).
    grid = (num_tokens, triton.cdiv(num_quant_blocks, 4))
    _indexer_k_quant_and_cache_kernel[grid](
        k,
        kv_cache_value,
        kv_cache_scale,
        slot_mapping,
        kv_cache_scale.stride(0),
        kv_cache_value.stride(0),
        block_size,
        num_quant_blocks,
        head_dim,
        quant_block_size,
        IS_FNUZ=_is_fp8_fnuz(fp8_dtype),
        USE_UE8M0=scale_fmt == "ue8m0",
    )