Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/reshape_and_cache_flash.py: 0%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils import libentry
9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Copy one token's key and value vectors into their slot of the block cache.
#
# Each program instance handles the token selected by program_id(0):
#   * the destination slot is read from slot_mapping[token];
#   * a negative slot makes the program exit without writing anything
#     (presumably marks padded tokens -- confirm against the caller);
#   * the slot is split into (slot // block_size, slot % block_size) and the
#     token's `n` elements are written starting at
#     block * block_stride + offset_in_block * num_heads * head_size.
#
# `n` is the number of elements copied per token. The lane range is rounded
# up to the next power of two and the surplus lanes are masked off.
# `k_scale` and `v_scale` are accepted but not read by this body.
@libentry()
@triton.jit
def reshape_and_cache_flash_kernel(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    block_stride,
    key_stride,
    value_stride,
    num_heads,
    head_size,
    block_size,
    k_scale,
    v_scale,
    n: tl.constexpr,
):
    pid = tl.program_id(0)
    slot = tl.load(slot_mapping + pid)
    # Nothing to write for a negative slot.
    if slot < 0:
        return

    # One lane per element of the token; lanes at or beyond n stay inactive.
    lane = tl.arange(0, triton.next_power_of_2(n))
    in_range = lane < n

    # Gather the token's source elements, n consecutive values from its row start.
    key_vals = tl.load(key + (pid * key_stride + lane), mask=in_range)
    value_vals = tl.load(value + (pid * value_stride + lane), mask=in_range)

    # Locate the slot inside the cache: which block, and which row of it.
    page = slot // block_size
    row = slot % block_size
    slot_base = page * block_stride + row * num_heads * head_size

    # Per-lane offset inside the slot, expressed as (head, position in head).
    head = lane // head_size
    within_head = lane % head_size
    dst = slot_base + head * head_size + within_head

    # TODO: support fp8 dtype
    tl.store(key_cache + dst, key_vals, mask=in_range)
    tl.store(value_cache + dst, value_vals, mask=in_range)
def reshape_and_cache_flash(
    key,  # [num_tokens, num_heads, head_size]
    value,  # [num_tokens, num_heads, head_size]
    key_cache,  # [num_blocks, block_size, num_heads, head_size]
    value_cache,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping,  # [num_tokens]
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Write per-token keys and values into the block caches in place.

    Launches ``reshape_and_cache_flash_kernel`` with one program per entry of
    ``slot_mapping``. Each program copies ``num_heads * head_size`` elements
    of ``key`` and ``value`` into ``key_cache`` and ``value_cache``.

    Args:
        key: Source keys, ``[num_tokens, num_heads, head_size]``.
        value: Source values, ``[num_tokens, num_heads, head_size]``.
        key_cache: Destination key cache,
            ``[num_blocks, block_size, num_heads, head_size]``.
        value_cache: Destination value cache, same layout as ``key_cache``.
        slot_mapping: Destination slot per token, ``[num_tokens]``. Its
            length sets the launch grid size.
        kv_cache_dtype: Not read by this function.
        k_scale: Passed through to the kernel unchanged.
        v_scale: Passed through to the kernel unchanged.

    Raises:
        AssertionError: If the two caches differ in their dim-0 stride.
    """
    logger.debug("GEMS RESHAPE_AND_CACHE_FLASH")

    # Problem sizes, taken from the tensor shapes.
    n_tokens = slot_mapping.size(0)
    n_heads = key.size(1)
    head_dim = key.size(2)
    slots_per_block = key_cache.size(1)

    # Element strides between consecutive tokens and consecutive cache blocks.
    key_token_stride = key.stride(0)
    value_token_stride = value.stride(0)
    cache_block_stride = key_cache.stride(0)
    # The kernel uses a single block stride for both caches.
    assert cache_block_stride == value_cache.stride(0)

    elems_per_token = n_heads * head_dim
    launch_grid = (n_tokens,)
    with torch_device_fn.device(key.device):
        reshape_and_cache_flash_kernel[launch_grid](
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            cache_block_stride,
            key_token_stride,
            value_token_stride,
            n_heads,
            head_dim,
            slots_per_block,
            k_scale,
            v_scale,
            elems_per_token,
        )