# src/flag_gems/runtime/backend/_kunlunxin/fused/reshape_and_cache.py

import logging

import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@libentry()
@triton.jit
def reshape_and_cache_kernel(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    key_stride,
    value_stride,
    num_heads,
    head_size,
    block_size,
    x,
    k_scale,
    v_scale,
    n: tl.constexpr,
):
    # One program per token: look up the cache slot this token is assigned to.
    token_idx = tl.program_id(0)
    slot_idx = tl.load(slot_mapping + token_idx)
    if slot_idx < 0:
        # A negative slot marks a padding token; nothing to cache.
        return

    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    # Cover all n = num_heads * head_size elements of this token at once,
    # padding the range up to a power of two and masking off the tail.
    i = tl.arange(0, triton.next_power_of_2(n))
    mask = i < n

    src_key_idx = token_idx * key_stride + i
    src_value_idx = token_idx * value_stride + i
    head_idx = i // head_size
    head_offset = i % head_size
    x_idx = head_offset // x
    x_offset = head_offset % x

    # Flat offset into key_cache, laid out as
    # [num_blocks, num_heads, head_size // x, block_size, x].
    tgt_key_idx = (
        block_idx * num_heads * (head_size // x) * block_size * x
        + head_idx * (head_size // x) * block_size * x
        + x_idx * block_size * x
        + block_offset * x
        + x_offset
    )
    # Flat offset into value_cache, laid out as
    # [num_blocks, num_heads, head_size, block_size].
    tgt_value_idx = (
        block_idx * num_heads * head_size * block_size
        + head_idx * head_size * block_size
        + head_offset * block_size
        + block_offset
    )

    tgt_key = tl.load(key + src_key_idx, mask=mask)
    tgt_value = tl.load(value + src_value_idx, mask=mask)

    # TODO: support fp8 dtype (k_scale and v_scale are unused until then)
    tl.store(key_cache + tgt_key_idx, tgt_key, mask=mask)
    tl.store(value_cache + tgt_value_idx, tgt_value, mask=mask)
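

# Worked example of the index arithmetic above (illustrative numbers, not part
# of the original file): with num_heads=2, head_size=8, block_size=4 and x=2,
# a token mapped to slot_idx=5 gets block_idx=1, block_offset=1. Its element
# i=10 has head_idx=1, head_offset=2, hence x_idx=1 and x_offset=0, giving
#     tgt_key_idx = 1*2*4*4*2 + 1*4*4*2 + 1*4*2 + 1*2 + 0 = 106,
# which is key_cache[1, 1, 1, 1, 0] in the
# [num_blocks, num_heads, head_size // x, block_size, x] layout.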

def reshape_and_cache(
    key,  # [num_tokens, num_heads, head_size]
    value,  # [num_tokens, num_heads, head_size]
    key_cache,  # [num_blocks, num_heads, head_size/x, block_size, x]
    value_cache,  # [num_blocks, num_heads, head_size, block_size]
    slot_mapping,  # [num_tokens]
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    logger.debug("GEMS RESHAPE_AND_CACHE")
    num_tokens = slot_mapping.size(0)
    num_heads = key.size(1)
    head_size = key.size(2)
    block_size = key_cache.size(3)
    x = key_cache.size(4)

    key_stride = key.stride(0)
    value_stride = value.stride(0)

    # One program per token; each program copies that token's whole key/value
    # row (num_heads * head_size elements) into the paged caches.
    grid = (num_tokens,)
    with torch_device_fn.device(key.device):
        reshape_and_cache_kernel[grid](
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            key_stride,
            value_stride,
            num_heads,
            head_size,
            block_size,
            x,
            k_scale,
            v_scale,
            num_heads * head_size,
        )
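

# A minimal eager-mode sketch of the same scatter, written as a plain PyTorch
# reference that can be handy when testing the kernel. The helper below is an
# illustrative assumption, not part of the FlagGems API; like the kernel above
# it ignores fp8 scaling (k_scale / v_scale).
def _reshape_and_cache_torch_ref(key, value, key_cache, value_cache, slot_mapping):
    _, num_heads, head_size_div_x, block_size, x = key_cache.shape
    valid = slot_mapping >= 0  # negative slots mark padding tokens; skip them
    slots = slot_mapping[valid].long()
    block_idx = slots // block_size
    block_off = slots % block_size
    # Split each head vector into head_size // x groups of x elements to match
    # the key-cache layout, then write one (block, offset) row per token.
    key_cache[block_idx, :, :, block_off, :] = key[valid].reshape(
        -1, num_heads, head_size_div_x, x
    )
    value_cache[block_idx, :, :, block_off] = value[valid]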