Coverage for src/flag_gems/fused/deepseek_v4_attention_dequantize_and_gather_k_cache.py: 17%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from typing import Optional 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9 

10def _default_scale_slots(nope_dim: int) -> int: 

11 return triton.cdiv(nope_dim, 64) + (1 if nope_dim % 64 == 0 else 0) 

12 

13 

14def _as_cache_2d(k_cache: torch.Tensor) -> torch.Tensor: 

15 if k_cache.ndim == 2: 

16 return k_cache 

17 if k_cache.ndim == 3: 

18 if k_cache.is_contiguous(): 

19 return k_cache.view(k_cache.shape[0], -1) 

20 return k_cache.contiguous().view(k_cache.shape[0], -1) 

21 raise ValueError(f"k_cache must be 2D or 3D, got shape={tuple(k_cache.shape)}") 

22 

23 

@triton.jit
def _dequantize_and_gather_k_cache_kernel(
    out_ptr,
    out_stride0,
    out_stride1,
    k_cache_ptr,
    seq_lens_ptr,
    block_table_ptr,
    offset,
    gather_lens_ptr,
    max_blocks_per_seq: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    scale_slots: tl.constexpr,
    quant_block: tl.constexpr,
    cache_block_size: tl.constexpr,
    token_data_size: tl.constexpr,
    cache_block_stride: tl.constexpr,
    output_dim: tl.constexpr,
    num_workers: tl.constexpr,
    HAVE_GATHER_LENS: tl.constexpr,
):
    # Dequantize the gathered tokens of one request from the paged K cache
    # into `out`. The grid is (num_reqs, num_workers): program (r, w) handles
    # local tokens w, w + num_workers, ... of request r.
    #
    # Cache block layout, addressed in bytes (assumes k_cache is a uint8
    # tensor), starting at k_cache_ptr + physical_block * cache_block_stride:
    #   cache_block_size * token_data_size bytes of token data, then
    #   cache_block_size * scale_slots bytes of per-token scale bytes.
    # Each token's data holds nope_dim fp8 (e4m3) bytes followed by rope_dim
    # bfloat16 values. Scale byte s decodes to 2 ** (s - 127) and scales the
    # quant_block-wide slice of the nope segment with the same index.
    #
    # max_blocks_per_seq is used as the row stride of block_table.
    # output_dim is unused; it is kept so the launch signature is unchanged.

    # Every quant block of the nope segment needs a scale slot. Otherwise the
    # scale reads would spill into the next token's scales.
    tl.static_assert(scale_slots * quant_block >= nope_dim)

    # int64 so that req_idx * out_stride0 cannot overflow int32 on large outputs.
    req_idx = tl.program_id(0).to(tl.int64)
    worker_idx = tl.program_id(1)
    seq_len = tl.load(seq_lens_ptr + req_idx)
    if HAVE_GATHER_LENS:
        gather_len = tl.load(gather_lens_ptr + req_idx)
    else:
        gather_len = seq_len
    # Gather the trailing gather_len tokens of the sequence.
    start_pos = seq_len - gather_len

    for local_i in range(worker_idx, gather_len, num_workers):
        pos = start_pos + local_i
        block_in_seq = pos // cache_block_size
        pos_in_block = pos - block_in_seq * cache_block_size
        physical_block = tl.load(
            block_table_ptr + req_idx * max_blocks_per_seq + block_in_seq
        )
        cache_block = k_cache_ptr + physical_block.to(tl.int64) * cache_block_stride
        token_data = cache_block + pos_in_block * token_data_size
        scale_base = (
            cache_block
            + cache_block_size * token_data_size
            + pos_in_block * scale_slots
        )
        out_row = out_ptr + req_idx * out_stride0 + (offset + local_i) * out_stride1

        # Only the first cdiv(nope_dim, quant_block) slots cover real data.
        # Trailing (padding) slots would have fully masked loads and stores,
        # so they are skipped.
        for qblock in tl.static_range(0, (nope_dim + quant_block - 1) // quant_block):
            qoffs = qblock * quant_block + tl.arange(0, quant_block)
            qmask = qoffs < nope_dim
            x_u8 = tl.load(token_data + qoffs, mask=qmask, other=0).to(tl.uint8)
            # Reinterpret the raw byte as fp8 e4m3, then widen for scaling.
            x_fp8 = x_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            encoded = tl.load(scale_base + qblock)
            scale = tl.exp2(encoded.to(tl.float32) - 127.0)
            x = x_fp8 * scale
            tl.store(out_row + qoffs, x.to(tl.bfloat16), mask=qmask)

        # The rope segment is stored as bfloat16 right after the fp8 bytes.
        bf16_ptr = (token_data + nope_dim).to(tl.pointer_type(tl.bfloat16))
        for rblock in tl.static_range(0, rope_dim, 16):
            roffs = rblock + tl.arange(0, 16)
            rmask = roffs < rope_dim
            vals = tl.load(bf16_ptr + roffs, mask=rmask, other=0.0)
            tl.store(out_row + nope_dim + roffs, vals, mask=rmask)

87 

88 

def dequantize_and_gather_k_cache(
    out: torch.Tensor,
    k_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    gather_lens: Optional[torch.Tensor],
    block_table: torch.Tensor,
    block_size: int,
    offset: int = 0,
    rope_dim: int = 64,
    nope_dim: Optional[int] = None,
    scale_slots: Optional[int] = None,
) -> None:
    """Gather and dequantize tokens from a paged, fp8-quantized K cache.

    For each request ``r < seq_lens.shape[0]``:

    - The gathered tokens are the last ``gather_lens[r]`` tokens of the
      sequence, or all ``seq_lens[r]`` tokens when ``gather_lens`` is None.
    - They are located through ``block_table[r]`` and written in order to
      ``out[r, offset + i, :nope_dim + rope_dim]`` as bfloat16.
    - The first ``nope_dim`` columns are fp8 values multiplied by their
      per-64-value scale.
    - The next ``rope_dim`` columns are copied from bfloat16 storage.
    - Columns beyond ``nope_dim + rope_dim`` are left untouched.

    Args:
        out: bfloat16 3D output, written in place. Dim 0 indexes requests and
            dim 1 indexes destination rows. It must be unit-stride in its last
            dimension.
        k_cache: paged cache, 2D or 3D, with one leading row per cache block.
            It is read as raw bytes; non-uint8 dtypes are reinterpreted as uint8.
        seq_lens: 1D per-request sequence lengths.
        gather_lens: optional 1D per-request count of trailing tokens to gather.
        block_table: 2D ``(num_reqs, blocks_per_seq)`` map from logical to
            physical cache block.
        block_size: tokens per cache block.
        offset: first destination row in ``out`` for every request.
        rope_dim: number of bfloat16 values stored after the fp8 segment.
        nope_dim: number of fp8 values per token; defaults to
            ``output_dim - rope_dim``.
        scale_slots: scale bytes stored per token; defaults to
            ``_default_scale_slots(nope_dim)``.

    Raises:
        AssertionError: on inconsistent shapes, dtypes, or layouts.
        ValueError: if ``k_cache`` is neither 2D nor 3D.
    """
    quant_block = 64  # fp8 values sharing one scale byte (passed to the kernel)

    assert out.ndim == 3 and out.dtype == torch.bfloat16
    # The kernel stores each output row with unit stride along the last dim.
    assert out.stride(-1) == 1
    assert seq_lens.ndim == 1 and block_table.ndim == 2
    assert seq_lens.shape[0] == block_table.shape[0] <= out.shape[0]
    assert offset >= 0
    num_reqs = seq_lens.shape[0]

    # The kernel indexes seq_lens / gather_lens as dense 1D arrays.
    seq_lens = seq_lens.contiguous()
    if gather_lens is not None:
        assert gather_lens.ndim == 1 and gather_lens.shape[0] >= num_reqs
        gather_lens = gather_lens.contiguous()

    output_dim = out.shape[-1]
    if nope_dim is None:
        nope_dim = output_dim - rope_dim
    if scale_slots is None:
        scale_slots = _default_scale_slots(nope_dim)
    assert nope_dim >= 0 and rope_dim >= 0
    assert nope_dim + rope_dim <= output_dim
    # Each quant_block-wide slice of the nope segment needs its own scale byte.
    assert scale_slots * quant_block >= nope_dim

    k_cache_2d = _as_cache_2d(k_cache)
    if k_cache_2d.dtype != torch.uint8:
        # The kernel does byte-offset arithmetic and decodes scale bytes as
        # unsigned values, so reinterpret (not convert) other dtypes as bytes.
        k_cache_2d = k_cache_2d.view(torch.uint8)

    # The kernel finds row r of block_table at r * max_blocks_per_seq, so it
    # needs the row *stride* rather than the row length. The two differ for
    # column-sliced tables. Column access assumes unit stride.
    if block_table.stride(-1) != 1:
        block_table = block_table.contiguous()
    block_table_row_stride = block_table.stride(0)

    if num_reqs == 0:
        return

    # Per-token data bytes: nope_dim fp8 bytes + rope_dim bfloat16 values.
    token_data_size = nope_dim + rope_dim * 2
    num_workers = 128
    with torch_device_fn.device(out.device):
        _dequantize_and_gather_k_cache_kernel[(num_reqs, num_workers)](
            out,
            out.stride(0),
            out.stride(1),
            k_cache_2d,
            seq_lens,
            block_table,
            offset,
            gather_lens,
            block_table_row_stride,
            nope_dim=nope_dim,
            rope_dim=rope_dim,
            scale_slots=scale_slots,
            quant_block=quant_block,
            cache_block_size=block_size,
            token_data_size=token_data_size,
            cache_block_stride=k_cache_2d.stride(0),
            output_dim=output_dim,
            num_workers=num_workers,
            HAVE_GATHER_LENS=gather_lens is not None,
        )

137 

138 

# Public API of this module; the helpers and the Triton kernel stay private.
__all__ = [
    "dequantize_and_gather_k_cache",
]