Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/reshape_and_cache_flash.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7from flag_gems.utils import libentry 

8 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module so its records can be filtered per-op.
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))

10 

11 

@libentry()
@triton.jit
def reshape_and_cache_flash_kernel(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    block_stride,
    key_stride,
    value_stride,
    num_heads,
    head_size,
    block_size,
    k_scale,
    v_scale,
    n: tl.constexpr,
):
    """Copy one token's key/value row into its paged-cache slot.

    One program handles one token (``tl.program_id(0)``). It reads ``n``
    consecutive elements starting at ``token * key_stride`` (resp.
    ``token * value_stride``) and writes them to the cache slot
    ``slot_mapping[token]``, located at
    ``(slot // block_size) * block_stride
    + (slot % block_size) * num_heads * head_size``.

    Args:
        key, value: source pointers; token ``t``'s row starts at
            ``t * key_stride`` / ``t * value_stride`` and is read as ``n``
            consecutive elements.
        key_cache, value_cache: destination pointers; both are addressed
            with the same ``block_stride``.
        slot_mapping: per-token slot index; a negative slot means the
            token is skipped.
        block_stride: element distance between consecutive cache blocks.
        key_stride, value_stride: element distance between consecutive
            source tokens.
        num_heads, head_size: used only through their product, the
            element distance between consecutive slots within a block.
        block_size: number of slots per cache block.
        k_scale, v_scale: currently unused (see the fp8 TODO below).
        n: elements copied per token; the host wrapper passes
            ``num_heads * head_size``.
    """
    # Promote to int64 before forming element offsets: in 32-bit arithmetic
    # token_idx * key_stride (or block_idx * block_stride) can overflow once
    # the tensors span more than 2**31 elements.
    token_idx = tl.program_id(0).to(tl.int64)
    slot_idx = tl.load(slot_mapping + token_idx).to(tl.int64)
    # Negative slot => this token must not be written to the cache.
    if slot_idx < 0:
        return

    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    # tl.arange needs a power-of-two extent; lanes past n are masked off.
    offsets = tl.arange(0, triton.next_power_of_2(n))
    mask = offsets < n

    # Inside a slot the (num_heads, head_size) row is addressed flat:
    # head_idx * head_size + head_offset == offsets, so no per-element
    # div/mod is required.
    tgt_idx = block_idx * block_stride + block_offset * num_heads * head_size + offsets

    key_row = tl.load(key + token_idx * key_stride + offsets, mask=mask)
    value_row = tl.load(value + token_idx * value_stride + offsets, mask=mask)

    # TODO: support fp8 dtype. k_scale / v_scale are not applied yet; the
    # store only converts to the cache's element type.
    tl.store(key_cache + tgt_idx, key_row, mask=mask)
    tl.store(value_cache + tgt_idx, value_row, mask=mask)

57 

58 

def _trailing_dims_contiguous(tensor):
    """Return True if every dim after the first is densely packed row-major.

    Size-1 dims are ignored because their stride never affects addressing,
    and tensors with no elements are trivially accepted.
    """
    if tensor.numel() == 0:
        return True
    expected_stride = 1
    for dim in range(tensor.dim() - 1, 0, -1):
        size = tensor.size(dim)
        if size != 1 and tensor.stride(dim) != expected_stride:
            return False
        expected_stride *= size
    return True


def reshape_and_cache_flash(
    key,  # [num_tokens, num_heads, head_size]
    value,  # [num_tokens, num_heads, head_size]
    key_cache,  # [num_blocks, block_size, num_heads, head_size]
    value_cache,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping,  # [num_tokens]
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Scatter per-token key/value rows into a paged (flash-layout) KV cache.

    For each token ``t`` with ``slot = slot_mapping[t] >= 0``, ``key[t]`` is
    written to ``key_cache[slot // block_size, slot % block_size]``, and
    likewise ``value[t]`` to ``value_cache``. Tokens with a negative slot are
    skipped. The caches are modified in place; nothing is returned.

    Args:
        key, value: per-token tensors indexed ``[token, head, dim]``. Only
            the first ``slot_mapping.size(0)`` rows are read.
        key_cache, value_cache: caches indexed
            ``[block, slot_in_block, head, dim]``; must share a block stride.
        slot_mapping: one slot index per token.
        kv_cache_dtype: accepted for API compatibility; not used here.
        k_scale, v_scale: forwarded to the kernel, which does not yet apply
            them (fp8 caches are not supported).

    Raises:
        ValueError: if ``key_cache`` and ``value_cache`` have different
            block strides, or if any tensor's dims after the first are not
            densely packed. The kernel addresses every token row and every
            cache slot as one flat run of ``num_heads * head_size`` elements,
            so any other layout would be silently mis-written.
    """
    logger.debug("GEMS RESHAPE_AND_CACHE_FLASH")
    num_tokens = slot_mapping.size(0)
    num_heads = key.size(1)
    head_size = key.size(2)
    block_size = key_cache.size(1)

    key_stride = key.stride(0)
    value_stride = value.stride(0)
    block_stride = key_cache.stride(0)
    # Explicit raise instead of assert so the check survives `python -O`.
    if value_cache.stride(0) != block_stride:
        raise ValueError(
            "reshape_and_cache_flash: key_cache and value_cache must share "
            f"the same block stride, got {block_stride} and "
            f"{value_cache.stride(0)}"
        )
    for name, tensor in (
        ("key", key),
        ("value", value),
        ("key_cache", key_cache),
        ("value_cache", value_cache),
    ):
        if not _trailing_dims_contiguous(tensor):
            raise ValueError(
                f"reshape_and_cache_flash: {name} must be contiguous in every "
                f"dim after the first, got strides {tensor.stride()}"
            )

    if num_tokens == 0:
        # Nothing to scatter; skip the kernel launch.
        return

    grid = (num_tokens,)
    with torch_device_fn.device(key.device):
        reshape_and_cache_flash_kernel[grid](
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            block_stride,
            key_stride,
            value_stride,
            num_heads,
            head_size,
            block_size,
            k_scale,
            v_scale,
            num_heads * head_size,
        )