Coverage for src/flag_gems/fused/reshape_and_cache_flash.py: 57%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.config import use_c_extension 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

14@libentry() 

15@triton.jit 

16def reshape_and_cache_flash_kernel( 

17 key, 

18 value, 

19 key_cache, 

20 value_cache, 

21 slot_mapping, 

22 block_stride, 

23 key_stride, 

24 value_stride, 

25 num_heads, 

26 head_size, 

27 block_size, 

28 k_scale, 

29 v_scale, 

30 n: tl.constexpr, 

31): 

32 token_idx = tl.program_id(0) 

33 slot_idx = tl.load(slot_mapping + token_idx) 

34 if slot_idx < 0: 

35 return 

36 

37 block_idx = slot_idx // block_size 

38 block_offset = slot_idx % block_size 

39 i = tl.arange(0, triton.next_power_of_2(n)) 

40 mask = i < n 

41 

42 src_key_idx = token_idx * key_stride + i 

43 src_value_idx = token_idx * value_stride + i 

44 head_idx = i // head_size 

45 head_offset = i % head_size 

46 tgt_key_value_idx = ( 

47 block_idx * block_stride 

48 + block_offset * num_heads * head_size 

49 + head_idx * head_size 

50 + head_offset 

51 ) 

52 

53 tgt_key = tl.load(key + src_key_idx, mask=mask) 

54 tgt_value = tl.load(value + src_value_idx, mask=mask) 

55 

56 # TODO: support fp8 dtype 

57 tl.store(key_cache + tgt_key_value_idx, tgt_key, mask=mask) 

58 tl.store(value_cache + tgt_key_value_idx, tgt_value, mask=mask) 

59 

60 

61def reshape_and_cache_flash( 

62 key, # [num_tokens, num_heads, head_size] 

63 value, # [num_tokens, num_heads, head_size] 

64 key_cache, # [num_blocks, block_size, num_heads, head_size] 

65 value_cache, # [num_blocks, block_size, num_heads, head_size] 

66 slot_mapping, # [num_tokens] 

67 kv_cache_dtype, 

68 k_scale, 

69 v_scale, 

70): 

71 if use_c_extension: 

72 logger.debug("GEMS RESHAPE_AND_CACHE_FLASH(C EXTENSION)") 

73 torch.ops.flag_gems.reshape_and_cache_flash( 

74 key, 

75 value, 

76 key_cache, 

77 value_cache, 

78 slot_mapping, 

79 kv_cache_dtype, 

80 k_scale, 

81 v_scale, 

82 ) 

83 else: 

84 logger.debug("GEMS RESHAPE_AND_CACHE_FLASH") 

85 num_tokens = slot_mapping.size(0) 

86 num_heads = key.size(1) 

87 head_size = key.size(2) 

88 block_size = key_cache.size(1) 

89 

90 key_stride = key.stride(0) 

91 value_stride = value.stride(0) 

92 block_stride = key_cache.stride(0) 

93 assert key_cache.stride(0) == value_cache.stride(0) 

94 

95 grid = (num_tokens,) 

96 with torch_device_fn.device(key.device): 

97 reshape_and_cache_flash_kernel[grid]( 

98 key, 

99 value, 

100 key_cache, 

101 value_cache, 

102 slot_mapping, 

103 block_stride, 

104 key_stride, 

105 value_stride, 

106 num_heads, 

107 head_size, 

108 block_size, 

109 k_scale, 

110 v_scale, 

111 num_heads * head_size, 

112 )