Coverage for src/flag_gems/fused/deepseek_v4_attention_dequantize_and_gather_k_cache.py: 17%
65 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1from typing import Optional
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
def _default_scale_slots(nope_dim: int) -> int:
    """Return the default number of scale bytes stored per cached token.

    Defined as ``ceil(nope_dim / 64)`` plus one extra slot when ``nope_dim``
    is an exact multiple of 64. Both cases collapse to ``nope_dim // 64 + 1``
    (e.g. 448 -> 8, 512 -> 9, 100 -> 2).
    """
    # NOTE(review): the extra slot for exact multiples looks like layout
    # padding (448 channels -> 7 groups + 1 pad byte); confirm against the
    # code that writes the cache.
    full_groups = nope_dim // 64
    return full_groups + 1
def _as_cache_2d(k_cache: torch.Tensor) -> torch.Tensor:
    """Return ``k_cache`` as a 2D ``(num_blocks, elements_per_block)`` tensor.

    The kernel locates a cache block through ``stride(0)`` and reads the
    elements inside a block as consecutive memory. The result therefore
    always has unit stride in its last dimension.

    A 3D cache is flattened past dim 0. This is a view when its inner two
    dims can be merged (even if dim 0 itself is strided), and a copy
    otherwise.

    Raises:
        ValueError: if ``k_cache`` is neither 2D nor 3D.
    """
    if k_cache.ndim == 3:
        num_blocks, rows, cols = k_cache.shape
        # reshape only copies when dims 1 and 2 cannot be merged, so a cache
        # strided solely along dim 0 is not duplicated. Explicit sizes (not
        # -1) keep an empty cache from raising.
        k_cache = k_cache.reshape(num_blocks, rows * cols)
    elif k_cache.ndim != 2:
        raise ValueError(f"k_cache must be 2D or 3D, got shape={tuple(k_cache.shape)}")
    if k_cache.stride(-1) != 1:
        # A row with gaps between elements (e.g. a column slice) cannot be
        # read as consecutive memory by the kernel.
        k_cache = k_cache.contiguous()
    return k_cache
@triton.jit
def _dequantize_and_gather_k_cache_kernel(
    out_ptr,
    out_stride0,
    out_stride1,
    k_cache_ptr,
    seq_lens_ptr,
    block_table_ptr,
    offset,
    gather_lens_ptr,
    max_blocks_per_seq: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    scale_slots: tl.constexpr,
    quant_block: tl.constexpr,
    cache_block_size: tl.constexpr,
    token_data_size: tl.constexpr,
    cache_block_stride: tl.constexpr,
    output_dim: tl.constexpr,
    num_workers: tl.constexpr,
    HAVE_GATHER_LENS: tl.constexpr,
):
    """Dequantize one request's cached tokens and write them into ``out``.

    Program ``(r, w)`` handles request ``r`` and walks its gathered window
    with step ``num_workers`` starting at ``w``. The window is the last
    ``gather_len`` positions of a ``seq_len``-token sequence, and window
    position ``i`` lands in ``out`` row ``offset + i``.

    Per cache block, ``cache_block_size`` token records of ``token_data_size``
    elements come first, followed by ``scale_slots`` scale entries per token.
    A token record holds ``nope_dim`` FP8 (e4m3) bytes and then ``rope_dim``
    bfloat16 values. The pointer arithmetic treats the cache as 1-byte
    elements. A scale entry ``e`` multiplies one ``quant_block``-wide FP8
    group by ``2 ** (e - 127)``.

    ``max_blocks_per_seq`` is used as the row stride (in elements) of
    ``block_table``. ``output_dim`` is not used by the kernel body.
    """
    req_idx = tl.program_id(0)
    worker_idx = tl.program_id(1)
    seq_len = tl.load(seq_lens_ptr + req_idx)
    if HAVE_GATHER_LENS:
        # Never gather more tokens than the sequence holds. A larger value
        # would make start_pos negative and index block_table out of bounds.
        gather_len = tl.minimum(tl.load(gather_lens_ptr + req_idx), seq_len)
    else:
        gather_len = seq_len
    start_pos = seq_len - gather_len
    # 64-bit row base: req_idx * out_stride0 can exceed int32 on big outputs.
    out_req = out_ptr + req_idx.to(tl.int64) * out_stride0

    for local_i in range(worker_idx, gather_len, num_workers):
        pos = start_pos + local_i
        block_in_seq = pos // cache_block_size
        pos_in_block = pos - block_in_seq * cache_block_size
        physical_block = tl.load(
            block_table_ptr + req_idx * max_blocks_per_seq + block_in_seq
        )
        cache_block = k_cache_ptr + physical_block.to(tl.int64) * cache_block_stride
        token_data = cache_block + pos_in_block * token_data_size
        # Scale entries are stored after all cache_block_size token records.
        scale_base = (
            cache_block
            + cache_block_size * token_data_size
            + pos_in_block * scale_slots
        )
        out_row = out_req + (offset + local_i).to(tl.int64) * out_stride1

        # Visit exactly the quant groups overlapping [0, nope_dim). Trailing
        # scale slots beyond them carry no FP8 data, so looping over all
        # scale_slots would only add fully-masked work.
        for qblock in tl.static_range(0, (nope_dim + quant_block - 1) // quant_block):
            qoffs = qblock * quant_block + tl.arange(0, quant_block)
            qmask = qoffs < nope_dim
            x_u8 = tl.load(token_data + qoffs, mask=qmask, other=0).to(tl.uint8)
            # Reinterpret the raw bytes as FP8 e4m3 (bitcast, not a value cast).
            x_fp8 = x_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            encoded = tl.load(scale_base + qblock)
            # Scale byte is a biased power-of-two exponent.
            scale = tl.exp2(encoded.to(tl.float32) - 127.0)
            x = x_fp8 * scale
            tl.store(out_row + qoffs, x.to(tl.bfloat16), mask=qmask)

        # The rope part follows the FP8 bytes and is copied as bfloat16.
        bf16_ptr = (token_data + nope_dim).to(tl.pointer_type(tl.bfloat16))
        for rblock in tl.static_range(0, rope_dim, 16):
            roffs = rblock + tl.arange(0, 16)
            rmask = roffs < rope_dim
            vals = tl.load(bf16_ptr + roffs, mask=rmask, other=0.0)
            tl.store(out_row + nope_dim + roffs, vals, mask=rmask)
def dequantize_and_gather_k_cache(
    out: torch.Tensor,
    k_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    gather_lens: Optional[torch.Tensor],
    block_table: torch.Tensor,
    block_size: int,
    offset: int = 0,
    rope_dim: int = 64,
    nope_dim: Optional[int] = None,
    scale_slots: Optional[int] = None,
) -> None:
    """Dequantize paged K-cache tokens and gather them into ``out`` in place.

    For request ``r``, the last ``gather_lens[r]`` tokens of its sequence are
    processed (all ``seq_lens[r]`` tokens when ``gather_lens`` is None).
    Tokens are located through ``block_table`` and written to
    ``out[r, offset + i, :]``. Output columns ``[0, nope_dim)`` hold FP8
    values times a per-64-channel power-of-two scale. Columns
    ``[nope_dim, nope_dim + rope_dim)`` are copied from bfloat16 data stored
    after the FP8 bytes.

    Args:
        out: bfloat16 3D output, written in place. Its last dimension must
            have unit stride.
        k_cache: paged cache with a 1-byte dtype, 2D ``(num_blocks, bytes)``
            or 3D (flattened past dim 0). Each block holds ``block_size``
            records of ``nope_dim + 2 * rope_dim`` bytes, followed by
            ``block_size * scale_slots`` scale bytes.
        seq_lens: 1D per-request sequence lengths.
        gather_lens: optional per-request count of trailing tokens to gather.
        block_table: 2D map from (request, logical block) to physical block.
        block_size: tokens per cache block.
        offset: first ``out`` row written for each request.
        rope_dim: number of bfloat16 channels per token.
        nope_dim: number of FP8 channels; defaults to
            ``out.shape[-1] - rope_dim``.
        scale_slots: scale bytes per token; defaults to
            ``_default_scale_slots(nope_dim)``.
    """
    assert out.ndim == 3 and out.dtype == torch.bfloat16
    # The kernel writes each output row as consecutive bfloat16 elements.
    assert out.stride(-1) == 1, "out must have unit stride in its last dimension"
    assert seq_lens.ndim == 1 and block_table.ndim == 2
    assert seq_lens.shape[0] == block_table.shape[0] <= out.shape[0]
    num_reqs = seq_lens.shape[0]
    if gather_lens is not None:
        # The kernel reads one gather length per request.
        assert gather_lens.numel() >= num_reqs
    # Every cache offset the kernel computes is a byte offset.
    assert k_cache.element_size() == 1, "k_cache must use a 1-byte dtype"

    output_dim = out.shape[-1]
    if nope_dim is None:
        nope_dim = output_dim - rope_dim
    if scale_slots is None:
        scale_slots = _default_scale_slots(nope_dim)
    assert nope_dim + rope_dim <= output_dim
    # Each 64-channel FP8 group needs its own scale byte. Fewer slots would
    # either leave tail channels unwritten or read another token's scales.
    assert scale_slots * 64 >= nope_dim, (
        f"scale_slots={scale_slots} cannot cover nope_dim={nope_dim}"
    )

    k_cache_2d = _as_cache_2d(k_cache)
    if num_reqs == 0:
        return  # nothing to gather; skip compiling and launching the kernel

    # The kernel indexes these as densely packed 1D arrays.
    seq_lens = seq_lens.contiguous()
    if gather_lens is not None:
        gather_lens = gather_lens.contiguous()
    # block_table rows are addressed through stride(0) below, so only the
    # entries within a row need to be adjacent. This also supports sliced
    # tables whose row stride differs from their width.
    if block_table.stride(1) != 1:
        block_table = block_table.contiguous()

    token_data_size = nope_dim + rope_dim * 2
    num_workers = 128
    with torch_device_fn.device(out.device):
        _dequantize_and_gather_k_cache_kernel[(num_reqs, num_workers)](
            out,
            out.stride(0),
            out.stride(1),
            k_cache_2d,
            seq_lens,
            block_table,
            offset,
            gather_lens,
            block_table.stride(0),  # row stride, consumed as max_blocks_per_seq
            nope_dim=nope_dim,
            rope_dim=rope_dim,
            scale_slots=scale_slots,
            quant_block=64,
            cache_block_size=block_size,
            token_data_size=token_data_size,
            cache_block_stride=k_cache_2d.stride(0),
            output_dim=output_dim,
            num_workers=num_workers,
            HAVE_GATHER_LENS=gather_lens is not None,
        )
# Public API of this module: only the host-side launcher is exported.
__all__ = [
    "dequantize_and_gather_k_cache",
]