Coverage for src/flag_gems/ops/get_paged_mqa_logits_metadata.py: 15%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _paged_mqa_logits_metadata_kernel(
    context_lens_ptr,
    context_lens_stride,
    schedule_metadata_ptr,
    batch_size,
    split_kv,
    num_sms,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the starting work position for one scheduling slot.

    Every program recomputes the full per-request segment prefix sum and
    then derives where its own slice of the work starts. Program ``sm_idx``
    writes two values to ``schedule_metadata_ptr``:

    * ``[2 * sm_idx]``     -- ``q_idx``: number of leading requests whose
      segments all lie before this program's first segment.
    * ``[2 * sm_idx + 1]`` -- ``kv_split_idx``: offset of the first segment
      inside request ``q_idx``.

    Args:
        context_lens_ptr: Pointer to the per-request context lengths.
        context_lens_stride: Element stride between consecutive requests.
        schedule_metadata_ptr: Pointer to the output buffer (two values per
            program).
        batch_size: Number of valid requests.
        split_kv: Number of context positions covered by one segment.
        num_sms: Number of workers the segments are divided among; must be
            non-zero because it is used as a divisor.
        BLOCK_SIZE: Compile-time lane count; must be >= ``batch_size`` so a
            single block covers every request.
    """
    sm_idx = tl.program_id(0)

    # One lane per request; lanes beyond the batch are masked off.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < batch_size

    # Masked lanes read as 0 so they contribute no segments.
    ctx_lens = tl.load(
        context_lens_ptr + offsets * context_lens_stride, mask=mask, other=0
    )

    # Segments per request = ceil(ctx_len / split_kv).
    num_segs = (ctx_lens + split_kv - 1) // split_kv
    num_segs = tl.where(mask, num_segs, 0)

    # Inclusive prefix sum: prefix_sum[i] = segments in requests 0..i.
    prefix_sum = tl.cumsum(num_segs, axis=0)

    # The prefix sum is non-decreasing (assuming non-negative context
    # lengths), so its maximum is the total segment count.
    total_segs = tl.max(prefix_sum)

    # Balanced split: each program gets q segments and the first r programs
    # get one extra, so program i starts at i * q + min(i, r).
    q = total_segs // num_sms
    r = total_segs % num_sms
    # tl.minimum instead of a ternary: no data-dependent branch, and the
    # operands are promoted to a common integer type when the program id
    # and the loaded lengths have different widths.
    min_r = tl.minimum(sm_idx, r)
    seg_starts = sm_idx * q + min_r

    # Count the requests that end at or before the starting segment; that
    # count is the index of the request containing the starting segment.
    is_le = (prefix_sum <= seg_starts) & mask
    q_idx = tl.sum(tl.where(is_le, 1, 0))

    # Segments consumed by the preceding requests (0 when q_idx == 0); the
    # remainder is the offset inside request q_idx.
    prev_mask = offsets < q_idx
    prev_prefix = tl.max(tl.where(prev_mask, prefix_sum, 0))
    kv_split_idx = seg_starts - prev_prefix

    # Two consecutive output slots per program.
    out_idx = sm_idx * 2
    tl.store(schedule_metadata_ptr + out_idx, q_idx)
    tl.store(schedule_metadata_ptr + out_idx + 1, kv_split_idx)

56 

57 

def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build the per-worker schedule metadata for paged MQA logits.

    Each request's context is divided into segments of 256 positions and
    the segments are distributed evenly over ``num_sms`` workers. Row ``i``
    of the result is ``(q_idx, kv_split_idx)``: the request index and the
    segment offset inside that request at which worker ``i`` starts. The
    extra last row (``i == num_sms``) marks the end of the work.

    Args:
        context_lens: 1-D tensor of per-request context lengths, or a 2-D
            tensor of which only the last column is used.
        block_size: Accepted for interface compatibility; not used by this
            implementation.
        num_sms: Number of workers to distribute segments over; must be
            positive.

    Returns:
        An ``int32`` tensor of shape ``(num_sms + 1, 2)`` on the device of
        ``context_lens``. It is all zeros when the batch is empty.

    Raises:
        ValueError: If ``context_lens`` is not 1-D or 2-D, or if
            ``num_sms`` is not positive.
    """
    SPLIT_KV = 256

    # Reject shapes the kernel cannot interpret: it only receives stride(0),
    # so any higher-rank input would be read incorrectly without an error.
    if context_lens.dim() not in (1, 2):
        raise ValueError(
            f"context_lens must be 1-D or 2-D, got {context_lens.dim()}-D"
        )
    # The kernel divides by num_sms, so zero must never reach it.
    if num_sms <= 0:
        raise ValueError(f"num_sms must be positive, got {num_sms}")

    # 1. Handle 1D / 2D input
    if context_lens.dim() == 2:
        batch_size, next_n = context_lens.shape
        # Only the last column is used; the resulting view is strided,
        # which is why its stride is forwarded to the kernel below.
        effective_context_lens = context_lens[:, next_n - 1]
    else:
        batch_size = context_lens.shape[0]
        effective_context_lens = context_lens

    schedule_metadata = torch.zeros(
        (num_sms + 1, 2), dtype=torch.int32, device=context_lens.device
    )

    # Edge case: nothing to schedule, so skip the launch and return zeros.
    if batch_size == 0:
        return schedule_metadata

    # 2. One program per worker plus one for the end-of-work row.
    grid = (num_sms + 1,)

    # A single block must cover the whole batch; round up to a power of two
    # with a floor of 16 lanes.
    BLOCK_SIZE = triton.next_power_of_2(max(16, batch_size))

    # 3. Launch kernel
    _paged_mqa_logits_metadata_kernel[grid](
        effective_context_lens,
        effective_context_lens.stride(0),
        schedule_metadata,
        batch_size,
        SPLIT_KV,
        num_sms,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return schedule_metadata