Coverage for src/flag_gems/fused/cp_gather_indexer_k_quant_cache.py: 43%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

# Adapted from vLLM v0.20.2:
# csrc/cache_kernels.cu::cp_gather_indexer_k_quant_cache_kernel

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

# Gather one (token, quant block) tile of quantized key bytes plus its scale
# from a paged cache into dense per-token outputs.
#
# Program axis 0 indexes the output token ``tid``; axis 1 indexes the quant
# block within the head dimension. The owning sequence is found by scanning
# ``cu_seqlen``; a token that lies in no sequence writes nothing.
#
# Element offsets used (in units of each pointer's element type):
#   source values : block_id * kv_cache_stride + block_offset * HEAD_DIM + col
#   source scales : block_id * kv_cache_scale_stride
#                   + block_offset * num_quant_blocks + quant_block_id
#   output values : tid * k_fp8_stride + col            (unit inner stride)
#   output scales : tid * num_quant_blocks + quant_block_id  (dense rows)
#
# ``batch_size`` is a runtime argument so that different batch sizes reuse one
# compiled kernel; only its power-of-two bound BATCH_BLOCK is specialized.
# QUANT_BLOCK_SIZE must be a power of two (tl.arange requirement).
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,
    kv_cache_scale_ptr,
    k_fp8_ptr,
    k_scale_ptr,
    block_table_ptr,
    cu_seqlen_ptr,
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    k_fp8_stride,
    num_quant_blocks,
    batch_size,
    HEAD_DIM: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    BATCH_BLOCK: tl.constexpr,
):
    tid = tl.program_id(0)
    quant_block_id = tl.program_id(1)

    # Find the sequence b with cu_seqlen[b] <= tid < cu_seqlen[b + 1]. Lanes
    # at or beyond batch_size are masked off; tl.max picks the highest
    # matching lane index, or -1 when no sequence contains this token.
    batch_offsets = tl.arange(0, BATCH_BLOCK)
    batch_mask = batch_offsets < batch_size
    seq_starts = tl.load(cu_seqlen_ptr + batch_offsets, mask=batch_mask, other=0)
    seq_ends = tl.load(cu_seqlen_ptr + batch_offsets + 1, mask=batch_mask, other=0)
    in_batch = (tid >= seq_starts) & (tid < seq_ends) & batch_mask
    batch_id = tl.max(tl.where(in_batch, batch_offsets, -1), axis=0)
    if batch_id < 0:
        return

    # Position of the token inside its sequence -> (logical block, slot).
    batch_start = tl.load(cu_seqlen_ptr + batch_id)
    batch_offset = tid - batch_start
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = batch_id * block_table_stride + block_table_id

    # Widen to int64 before multiplying by per-block / per-row strides: for a
    # large cache, block_id * kv_cache_stride exceeds the int32 range and
    # would wrap to a negative (out-of-bounds) address.
    block_id = tl.load(block_table_ptr + block_table_offset).to(tl.int64)
    token = tid.to(tl.int64)

    offsets = quant_block_id * QUANT_BLOCK_SIZE + tl.arange(0, QUANT_BLOCK_SIZE)
    mask = offsets < HEAD_DIM
    src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
    src_scale_offset = (
        block_id * kv_cache_scale_stride
        + block_offset * num_quant_blocks
        + quant_block_id
    )

    # One scalar scale per (token, quant block), then the block's value bytes.
    scale_val = tl.load(kv_cache_scale_ptr + src_scale_offset)
    tl.store(k_scale_ptr + token * num_quant_blocks + quant_block_id, scale_val)
    val = tl.load(kv_cache_ptr + src_cache_offset + offsets, mask=mask)
    tl.store(k_fp8_ptr + token * k_fp8_stride + offsets, val, mask=mask)

65 

66 

def cp_gather_indexer_k_quant_cache(
    k_cache: torch.Tensor,
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
):
    """Gather quantized indexer keys and their scales from a paged cache.

    For every output token ``t`` the kernel locates the sequence ``b`` with
    ``cu_seqlen[b] <= t < cu_seqlen[b + 1]``, resolves the token's physical
    cache block through ``block_table[b]``, and copies the token's
    ``head_dim`` value bytes and its per-quant-block float32 scales into
    ``k_fp8`` / ``k_fp8_scale`` in place. Tokens that fall in no sequence are
    left untouched.

    Args:
        k_cache: Paged cache indexed as ``(num_blocks, block_size, ...)``.
            Each block, flattened, holds ``block_size * head_dim`` value
            elements followed by a region reinterpreted as float32 scales.
            Presumably a 1-byte dtype (uint8/fp8) — the element-offset
            slicing assumes it; TODO confirm with callers.
        k_fp8: Output ``(num_tokens, head_dim)`` tensor, expected to have a
            1-byte dtype; written in place.
        k_fp8_scale: Output ``(num_tokens, 4 * num_quant_blocks)`` byte
            tensor; each row is reinterpreted as ``num_quant_blocks`` float32
            scales and written in place.
        block_table: ``(batch_size, max_blocks)`` table of physical block ids.
            Only ``stride(0)`` is forwarded, so columns must be unit-stride.
        cu_seqlen: Cumulative sequence lengths with at least
            ``batch_size + 1`` entries.

    Raises:
        ValueError: if the scale width is inconsistent with ``head_dim``, the
            implied quant block size is not a power of two, or ``cu_seqlen``
            is too short for ``block_table``.
    """
    num_tokens = k_fp8.size(0)
    block_size = k_cache.size(1)
    block_table_stride = block_table.stride(0)
    head_dim = k_fp8.shape[-1]
    num_blocks = k_cache.shape[0]
    batch_size = block_table.shape[0]

    # One float32 (4 bytes) of scale per quant block of the head dimension.
    scale_bytes = k_fp8_scale.size(1)
    if scale_bytes <= 0:
        raise ValueError("k_fp8_scale must have a non-empty second dimension")
    quant_block_size = head_dim * 4 // scale_bytes
    if quant_block_size <= 0 or head_dim % quant_block_size != 0:
        raise ValueError("head_dim must be divisible by quant_block_size")
    num_quant_blocks = head_dim // quant_block_size
    if num_quant_blocks * 4 != scale_bytes:
        # Otherwise the kernel's dense scale indexing runs past each row.
        raise ValueError(
            "k_fp8_scale.size(1) must equal 4 * (head_dim // quant_block_size)"
        )
    if quant_block_size & (quant_block_size - 1):
        # tl.arange inside the kernel only accepts power-of-two extents.
        raise ValueError("quant_block_size must be a power of two")
    if cu_seqlen.numel() < batch_size + 1:
        # The kernel reads cu_seqlen[b + 1] for every b < batch_size.
        raise ValueError(
            "cu_seqlen must have at least block_table.shape[0] + 1 entries"
        )
    if num_tokens == 0:
        return

    # Split each flattened cache block into its value and scale regions.
    k_cache_flat = k_cache.view(num_blocks, -1)
    value_elems = block_size * head_dim
    k_cache_value = k_cache_flat[:, :value_elems]
    k_cache_scale = k_cache_flat[:, value_elems:].view(torch.float32)

    k_fp8_bytes = k_fp8.view(torch.uint8)
    k_fp8_scale_f32 = k_fp8_scale.view(torch.float32)

    # The kernel addresses values as ``row * stride(0) + col`` (unit inner
    # stride) and scales as ``token * num_quant_blocks + block`` (dense rows).
    # For outputs laid out otherwise, work on dense copies -- which keep the
    # current contents so untouched tokens stay unchanged -- and write back.
    k_out = k_fp8_bytes if k_fp8_bytes.stride(-1) == 1 else k_fp8_bytes.contiguous()
    scale_out = k_fp8_scale_f32.contiguous()

    # Power-of-two bound for the in-kernel sequence scan; at least 1 so that
    # tl.arange stays valid when block_table is empty.
    batch_block = triton.next_power_of_2(max(batch_size, 1))

    _cp_gather_indexer_quant_cache_kernel[(num_tokens, num_quant_blocks)](
        k_cache_value,
        k_cache_scale,
        k_out,
        scale_out,
        block_table,
        cu_seqlen,
        block_size,
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        k_out.stride(0),
        num_quant_blocks,
        batch_size,
        head_dim,
        quant_block_size,
        batch_block,
    )

    if k_out is not k_fp8_bytes:
        k_fp8_bytes.copy_(k_out)
    if scale_out is not k_fp8_scale_f32:
        k_fp8_scale_f32.copy_(scale_out)