Coverage for src/flag_gems/fused/reshape_and_cache_flash.py: 57% (44 statements)
import logging

import torch
import triton
import triton.language as tl

from flag_gems.config import use_c_extension
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def reshape_and_cache_flash_kernel(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    block_stride,
    key_stride,
    value_stride,
    num_heads,
    head_size,
    block_size,
    k_scale,
    v_scale,
    n: tl.constexpr,
):
    # One program per token: copy its flattened [num_heads, head_size]
    # key/value rows into the paged cache slot given by slot_mapping.
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping + token_idx)
    if slot_idx < 0:
        # Negative slots mark padding tokens; nothing to write.
        return

    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    i = tl.arange(0, triton.next_power_of_2(n))
    mask = i < n

    src_key_idx = token_idx * key_stride + i
    src_value_idx = token_idx * value_stride + i
    head_idx = i // head_size
    head_offset = i % head_size
    tgt_key_value_idx = (
        block_idx * block_stride
        + block_offset * num_heads * head_size
        + head_idx * head_size
        + head_offset
    )

    tgt_key = tl.load(key + src_key_idx, mask=mask)
    tgt_value = tl.load(value + src_value_idx, mask=mask)

    # TODO: support fp8 dtype
    tl.store(key_cache + tgt_key_value_idx, tgt_key, mask=mask)
    tl.store(value_cache + tgt_key_value_idx, tgt_value, mask=mask)
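
# Index math at a glance: a flat slot s maps to block s // block_size and to
# row s % block_size within that block. With block_size = 16, for example,
# slot 17 lands in block 1 at offset 1. The per-element target offset,
# block_idx * block_stride + block_offset * num_heads * head_size
#   + head_idx * head_size + head_offset,
# walks the flattened [num_blocks, block_size, num_heads, head_size] layout,
# which assumes the caches are contiguous in their last three dimensions
# (only the leading block stride is taken from the tensor).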
def reshape_and_cache_flash(
    key,  # [num_tokens, num_heads, head_size]
    value,  # [num_tokens, num_heads, head_size]
    key_cache,  # [num_blocks, block_size, num_heads, head_size]
    value_cache,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping,  # [num_tokens]
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    if use_c_extension:
        logger.debug("GEMS RESHAPE_AND_CACHE_FLASH(C EXTENSION)")
        torch.ops.flag_gems.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    else:
        logger.debug("GEMS RESHAPE_AND_CACHE_FLASH")
        num_tokens = slot_mapping.size(0)
        num_heads = key.size(1)
        head_size = key.size(2)
        block_size = key_cache.size(1)

        key_stride = key.stride(0)
        value_stride = value.stride(0)
        block_stride = key_cache.stride(0)
        assert key_cache.stride(0) == value_cache.stride(0)

        grid = (num_tokens,)
        with torch_device_fn.device(key.device):
            reshape_and_cache_flash_kernel[grid](
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                block_stride,
                key_stride,
                value_stride,
                num_heads,
                head_size,
                block_size,
                k_scale,
                v_scale,
                num_heads * head_size,  # n: elements copied per token
            )
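
# A minimal smoke test (a sketch, not part of the library): shapes follow the
# parameter comments on reshape_and_cache_flash above. The kv_cache_dtype
# value and the scalar k_scale/v_scale are assumptions; fp8 scaling is still a
# TODO in the kernel, so the scales are unused on the Triton path.
if __name__ == "__main__":
    device = "cuda"  # assumes a Triton-capable GPU
    num_tokens, num_heads, head_size = 4, 8, 64
    num_blocks, block_size = 2, 16

    key = torch.randn(
        num_tokens, num_heads, head_size, device=device, dtype=torch.float16
    )
    value = torch.randn_like(key)
    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size,
        device=device, dtype=torch.float16,
    )
    value_cache = torch.zeros_like(key_cache)
    # Slots 0 and 1 fall in block 0; slot 17 lands in block 1 at offset 1;
    # the -1 entry marks a padding token that the kernel skips.
    slot_mapping = torch.tensor([0, 1, 17, -1], device=device, dtype=torch.int64)

    reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping,
        kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0,
    )
    torch.testing.assert_close(key_cache[0, 0], key[0])
    torch.testing.assert_close(value_cache[1, 1], value[2])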