Coverage for src/flag_gems/fused/deepseek_v4_attention_combine_topk_swa_indices.py: 48%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1from typing import Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

# Granularity of the combined index tensor's column count:
# combine_topk_swa_indices rounds (topk + window_size) up to a multiple of it.
_SPARSE_PREFILL_TOPK_ALIGNMENT = 128

10 

11 

def _next_power_of_2_or_1(x: int) -> int:
    """Round ``x`` up to a power of two, with a floor of 1.

    Inputs at or below 1 (including zero and negatives) yield 1 without
    consulting Triton; larger inputs are delegated to
    ``triton.next_power_of_2``.
    """
    if x <= 1:
        return 1
    return triton.next_power_of_2(x)

14 

15 

@triton.jit
def _combine_topk_swa_indices_kernel(
    combined_ptr,
    combined_stride,
    lens_ptr,
    topk_ptr,
    topk_stride,
    query_start_loc_ptr,
    seq_lens_ptr,
    gather_lens_ptr,
    M,
    N,
    TOP_K: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    PADDED_TOP_K: tl.constexpr,
    PADDED_WINDOW_SIZE: tl.constexpr,
):
    """Write, per query token, its top-k indices followed by its window indices.

    Launch layout: program axis 0 selects the request, program axis 1 selects
    a worker. The workers of one request share that request's tokens, each
    taking every ``num_programs(1)``-th token starting at its own worker id.

    For a token at position ``pos`` (``seq_len - query_len`` plus its offset
    inside the query) the kernel

    * copies the first ``min((pos + 1) // COMPRESS_RATIO, TOP_K)`` entries of
      the token's row of ``topk_ptr`` into the token's row of
      ``combined_ptr``, adding ``M * request_index`` to each value;
    * appends ``min(pos + 1, WINDOW_SIZE)`` consecutive values right after
      them, starting at
      ``M * request_index + N + pos - swa_len + 1 - (seq_len - gather_len)``;
    * stores the sum of both counts at ``lens_ptr + token``.

    Row indices are taken relative to the first entry of
    ``query_start_loc_ptr``. Columns of the output row past the written
    entries are left untouched.

    NOTE(review): ``M`` looks like a per-request stride and ``N`` like the
    offset of the sliding-window region inside a request's slice -- confirm
    against the caller. Elements within a row are addressed with unit stride.
    """
    req = tl.program_id(0)
    lane = tl.program_id(1)
    n_lanes = tl.num_programs(1)

    # Rebase the cumulative query offsets so the first request starts at 0.
    first_loc = tl.load(query_start_loc_ptr)
    q_begin = tl.load(query_start_loc_ptr + req) - first_loc
    q_end = tl.load(query_start_loc_ptr + req + 1) - first_loc

    total_len = tl.load(seq_lens_ptr + req)
    gathered = tl.load(gather_lens_ptr + req)
    # Position of the request's first query token.
    first_pos = total_len - (q_end - q_begin)
    # Position that the first gathered entry corresponds to.
    window_origin = total_len - gathered

    # Loop-invariant pieces: column offsets and the per-request index offset.
    topk_cols = tl.arange(0, PADDED_TOP_K)
    swa_cols = tl.arange(0, PADDED_WINDOW_SIZE)
    req_offset = M * req

    for tok in range(q_begin + lane, q_end, n_lanes):
        pos = first_pos + (tok - q_begin)
        n_topk = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K)
        n_swa = tl.minimum(pos + 1, WINDOW_SIZE)
        out_row = combined_ptr + tok * combined_stride

        # Top-k part: shift the valid prefix of the row by the request offset.
        keep_topk = topk_cols < n_topk
        picked = tl.load(
            topk_ptr + tok * topk_stride + topk_cols,
            mask=keep_topk,
            other=-1,
        )
        tl.store(out_row + topk_cols, picked + req_offset, mask=keep_topk)

        # Window part: consecutive indices ending at the token's own position,
        # written immediately after the top-k entries.
        keep_swa = (swa_cols < n_swa) & (swa_cols < WINDOW_SIZE)
        tl.store(
            out_row + n_topk + swa_cols,
            req_offset + N + swa_cols + pos - n_swa + 1 - window_origin,
            mask=keep_swa,
        )

        tl.store(lens_ptr + tok, n_topk + n_swa)

70 

71 

def combine_topk_swa_indices(
    topk_indices: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    gather_lens: torch.Tensor,
    window_size: int,
    compress_ratio: int,
    topk: int,
    M: int,
    N: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge per-token top-k indices with sliding-window indices.

    Args:
        topk_indices: 2-D tensor with one row of top-k indices per token.
        query_start_loc: Cumulative query offsets; entries ``i`` and ``i + 1``
            bound the tokens of request ``i``.
        seq_lens: One sequence length per request; its first dimension sets
            the number of requests processed.
        gather_lens: One gather length per request.
        window_size: Maximum number of sliding-window entries per token.
        compress_ratio: Divisor applied to ``position + 1`` to get the number
            of usable top-k entries for a token.
        topk: Upper bound on the number of top-k entries kept per token.
        M: Multiplied by the request index and added to every emitted index.
        N: Extra offset added to the sliding-window indices only.

    Returns:
        ``(combined, lens)``. ``combined`` is an int32 tensor of shape
        ``(num_tokens, width)`` pre-filled with -1, where ``width`` is
        ``topk + window_size`` rounded up to a multiple of
        ``_SPARSE_PREFILL_TOPK_ALIGNMENT``. ``lens`` is an int32 tensor of
        shape ``(num_tokens,)`` holding the number of entries written per
        row; it is allocated uninitialized, so rows of tokens not covered by
        ``query_start_loc`` keep arbitrary values.
    """
    assert topk_indices.ndim == 2
    num_tokens, topk_width = topk_indices.shape
    device = topk_indices.device

    # Round the row width up to the next multiple of the alignment.
    aligned_blocks = (
        topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1
    ) // _SPARSE_PREFILL_TOPK_ALIGNMENT
    row_width = aligned_blocks * _SPARSE_PREFILL_TOPK_ALIGNMENT

    combined = torch.full(
        (num_tokens, row_width), -1, device=device, dtype=torch.int32
    )
    lens = torch.empty((num_tokens,), device=device, dtype=torch.int32)

    # One program row per request, 128 workers striding over its tokens.
    grid = (seq_lens.shape[0], 128)
    with torch_device_fn.device(device):
        _combine_topk_swa_indices_kernel[grid](
            combined,
            combined.stride(0),
            lens,
            topk_indices,
            topk_indices.stride(0),
            query_start_loc,
            seq_lens,
            gather_lens,
            M,
            N,
            TOP_K=topk,
            COMPRESS_RATIO=compress_ratio,
            WINDOW_SIZE=window_size,
            PADDED_TOP_K=_next_power_of_2_or_1(topk_width),
            PADDED_WINDOW_SIZE=_next_power_of_2_or_1(window_size),
        )
    return combined, lens

114 

115 

# Only the host-side wrapper is exported; the kernel and helpers stay private.
__all__ = [
    "combine_topk_swa_indices",
]