Coverage for src/flag_gems/fused/reshape_and_cache.py: 51%
41 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils import libentry
9logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def reshape_and_cache_kernel(
    key,  # source keys, flat-indexed via key_stride; presumably [num_tokens, num_heads, head_size] — see wrapper
    value,  # source values, flat-indexed via value_stride
    key_cache,  # destination paged key cache; indexing below assumes [num_blocks, num_heads, head_size//x, block_size, x]
    value_cache,  # destination paged value cache; indexing below assumes [num_blocks, num_heads, head_size, block_size]
    slot_mapping,  # per-token flat slot id; negative means "skip this token"
    key_stride,  # elements between consecutive tokens in `key`
    value_stride,  # elements between consecutive tokens in `value`
    num_heads,
    head_size,
    block_size,  # tokens per cache block
    x,  # vector width of the innermost key-cache dimension
    k_scale,  # currently unused (fp8 quantization not yet supported, see TODO)
    v_scale,  # currently unused (fp8 quantization not yet supported, see TODO)
    n: tl.constexpr,  # num_heads * head_size: elements copied per token
):
    """Copy one token's K/V vectors into the paged KV caches (one program per token)."""
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping + token_idx)
    # A negative slot marks a token with no cache destination (e.g. padding) — skip it.
    if slot_idx < 0:
        return

    # Decompose the flat slot id into (cache block, position within block).
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    # constexpr n lets next_power_of_2 be evaluated at compile time; mask trims the excess lanes.
    i = tl.arange(0, triton.next_power_of_2(n))
    mask = i < n

    src_key_idx = token_idx * key_stride + i
    src_value_idx = token_idx * value_stride + i
    # Split the flat per-token offset into (head, position within head).
    head_idx = i // head_size
    head_offset = i % head_size
    # Further split the within-head position for the x-vectorized key-cache layout.
    x_idx = head_offset // x
    x_offset = head_offset % x

    # Flat offset into key_cache laid out as [num_blocks, num_heads, head_size//x, block_size, x].
    tgt_key_idx = (
        block_idx * num_heads * (head_size // x) * block_size * x
        + head_idx * (head_size // x) * block_size * x
        + x_idx * block_size * x
        + block_offset * x
        + x_offset
    )
    # Flat offset into value_cache laid out as [num_blocks, num_heads, head_size, block_size].
    tgt_value_idx = (
        block_idx * num_heads * head_size * block_size
        + head_idx * head_size * block_size
        + head_offset * block_size
        + block_offset
    )

    tgt_key = tl.load(key + src_key_idx, mask=mask)
    tgt_value = tl.load(value + src_value_idx, mask=mask)

    # TODO: support fp8 dtype
    tl.store(key_cache + tgt_key_idx, tgt_key, mask=mask)
    tl.store(value_cache + tgt_value_idx, tgt_value, mask=mask)
def reshape_and_cache(
    key,  # [num_tokens, num_heads, head_size]
    value,  # [num_tokens, num_heads, head_size]
    key_cache,  # [num_blocks, num_heads, head_size/x, block_size, x]
    value_cache,  # [num_blocks, num_heads, head_size, block_size]
    slot_mapping,  # [num_tokens]
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Scatter per-token key/value tensors into the paged KV caches.

    Launches one kernel program per token; tokens whose slot id is
    negative are skipped inside the kernel.
    """
    logger.debug("GEMS RESHAPE_AND_CACHE")

    # Derive launch parameters from the tensor shapes.
    n_tokens = slot_mapping.size(0)
    n_heads = key.size(1)
    head_dim = key.size(2)
    blk_size = key_cache.size(3)
    vec_width = key_cache.size(4)

    with torch_device_fn.device(key.device):
        reshape_and_cache_kernel[(n_tokens,)](
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            key.stride(0),
            value.stride(0),
            n_heads,
            head_dim,
            blk_size,
            vec_width,
            k_scale,
            v_scale,
            n_heads * head_dim,  # constexpr n: elements copied per token
        )