Coverage for src/flag_gems/ops/get_paged_mqa_logits_metadata.py: 15%
40 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _paged_mqa_logits_metadata_kernel(
    context_lens_ptr,
    context_lens_stride,
    schedule_metadata_ptr,
    batch_size,
    split_kv,
    num_sms,
    BLOCK_SIZE: tl.constexpr,
):
    """Write the starting (q_idx, kv_split_idx) pair for one program.

    Every program reads all ``batch_size`` context lengths in a single block
    of ``BLOCK_SIZE`` lanes, converts each length into a number of segments
    of ``split_kv`` elements (rounded up), and divides the total segment
    count among ``num_sms`` workers as evenly as possible. Program ``i``
    then stores, at flat offsets ``2 * i`` and ``2 * i + 1`` of
    ``schedule_metadata_ptr``:

    * ``q_idx``: how many batch entries end at or before this program's
      first segment, i.e. the index of the entry that segment belongs to;
    * ``kv_split_idx``: the offset of that first segment inside the entry.

    Args:
        context_lens_ptr: Pointer to the per-entry context lengths.
        context_lens_stride: Element stride between consecutive lengths.
        schedule_metadata_ptr: Output pointer, indexed as a flat array of
            pairs.
        batch_size: Number of valid lengths; lanes beyond it are masked.
        split_kv: Segment length used to split each context.
        num_sms: Number of workers the segments are divided among.
        BLOCK_SIZE: Compile-time lane count; lanes at or beyond
            ``batch_size`` are masked out, so it has to be at least
            ``batch_size`` for every entry to be read.
    """
    sm_idx = tl.program_id(0)

    # 1. One lane per batch entry; lanes past the batch are inactive.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < batch_size

    # 2. Load the effective context length of every entry.
    ctx_lens = tl.load(
        context_lens_ptr + offsets * context_lens_stride, mask=mask, other=0
    )

    # 3. Segments per entry (ceil division); inactive lanes contribute zero.
    num_segs = (ctx_lens + split_kv - 1) // split_kv
    num_segs = tl.where(mask, num_segs, 0)

    # 4. Inclusive prefix sum: prefix_sum[i] = segments in entries 0..i.
    prefix_sum = tl.cumsum(num_segs, axis=0)

    # 5. The prefix sum is non-decreasing, so its max is the grand total.
    total_segs = tl.max(prefix_sum)

    # 6. Even split: each worker gets q segments and the first r workers get
    #    one extra, so worker i starts at i * q + min(i, r).
    #    tl.minimum promotes mixed integer widths. The previous conditional
    #    expression needed both arms to have one type, which does not hold
    #    when the lengths are not int32 (program_id is int32).
    #    NOTE(review): confirm against the Triton version in use.
    q = total_segs // num_sms
    r = total_segs % num_sms
    min_r = tl.minimum(sm_idx, r)
    seg_starts = sm_idx * q + min_r

    # 7. Entries whose inclusive prefix is <= the start are fully consumed
    #    before this worker begins. The mask is required: inactive lanes carry
    #    prefix_sum == total_segs and would otherwise be counted when
    #    seg_starts == total_segs.
    is_le = (prefix_sum <= seg_starts) & mask
    q_idx = tl.sum(tl.where(is_le, 1, 0))

    # 8. Offset inside entry q_idx: subtract the segments of all earlier
    #    entries (0 when q_idx == 0, since no lane is selected).
    prev_mask = offsets < q_idx
    prev_prefix = tl.max(tl.where(prev_mask, prefix_sum, 0))
    kv_split_idx = seg_starts - prev_prefix

    # 9. Store the pair for this program.
    out_idx = sm_idx * 2
    tl.store(schedule_metadata_ptr + out_idx, q_idx)
    tl.store(schedule_metadata_ptr + out_idx + 1, kv_split_idx)
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build the per-worker schedule for splitting contexts into segments.

    Each context is split into segments of 256 elements and the segments are
    divided evenly among ``num_sms`` workers. Row ``i`` of the result holds
    the ``(q_idx, kv_split_idx)`` position at which worker ``i`` starts; the
    extra final row is the position just past the last segment.

    Args:
        context_lens: 1-D tensor of context lengths, or a 2-D tensor whose
            last column is used as the effective length of each row.
        block_size: Accepted for interface compatibility; this function
            does not read it.
        num_sms: Number of workers to divide the segments among. Must be
            at least 1.

    Returns:
        An int32 tensor of shape ``(num_sms + 1, 2)`` on the device of
        ``context_lens``. It is all zeros when the batch is empty.

    Raises:
        ValueError: If ``num_sms`` is less than 1, or ``context_lens`` is
            neither 1-D nor 2-D.
    """
    SPLIT_KV = 256

    # The kernel divides by num_sms, so zero or negative values are invalid.
    if num_sms < 1:
        raise ValueError(f"num_sms must be a positive integer, got {num_sms}")

    # 1. Handle 1D / 2D input; other ranks would be silently misread.
    rank = context_lens.dim()
    if rank == 2:
        batch_size, next_n = context_lens.shape
        effective_context_lens = context_lens[:, next_n - 1]
    elif rank == 1:
        batch_size = context_lens.shape[0]
        effective_context_lens = context_lens
    else:
        raise ValueError(
            f"context_lens must be a 1-D or 2-D tensor, got {rank} dimension(s)"
        )

    schedule_metadata = torch.zeros(
        (num_sms + 1, 2), dtype=torch.int32, device=context_lens.device
    )

    # Edge case: with an empty batch there is nothing to schedule.
    if batch_size == 0:
        return schedule_metadata

    # 2. One program per output row; a single block must cover the whole
    #    batch, so round the lane count up to a power of two (minimum 16).
    grid = (num_sms + 1,)
    BLOCK_SIZE = triton.next_power_of_2(max(16, batch_size))

    # 3. Launch kernel. The element stride is passed explicitly because the
    #    2-D case yields a column view that is generally not contiguous.
    _paged_mqa_logits_metadata_kernel[grid](
        effective_context_lens,
        effective_context_lens.stride(0),
        schedule_metadata,
        batch_size,
        SPLIT_KV,
        num_sms,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return schedule_metadata