Coverage for src/flag_gems/fused/apply_repetition_penalties.py: 53%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _repetition_penalty_kernel(
    logits_ptr,        # *float: [num_seqs, vocab_size] logits, modified in place
    prompt_mask_ptr,   # *bool: [num_seqs, vocab_size] tokens seen in the prompt
    output_mask_ptr,   # *bool: [num_seqs, vocab_size] tokens already generated
    penalties_ptr,     # *float: [num_seqs] per-sequence repetition penalty
    num_seqs,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Apply repetition penalties to one BLOCK_SIZE-wide tile of the vocab for
    # one sequence. Grid: axis 0 = sequence index, axis 1 = vocab tile index.
    seq_idx = tl.program_id(0)
    vocab_offset = tl.program_id(1) * BLOCK_SIZE

    # Guard against an over-provisioned grid on axis 0.
    if seq_idx >= num_seqs:
        return

    penalty = tl.load(penalties_ptr + seq_idx)

    vocab_idx = vocab_offset + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes past the end of the vocab in the last tile.
    valid_vocab = vocab_idx < vocab_size

    # Flat offsets into the contiguous [num_seqs, vocab_size] buffers; the
    # masks share the same layout as the logits.
    logits_idx = seq_idx * vocab_size + vocab_idx
    mask_idx = logits_idx

    prompt_mask = tl.load(prompt_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    output_mask = tl.load(output_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    logits = tl.load(logits_ptr + logits_idx, mask=valid_vocab, other=0.0)

    # A token is penalized if it appeared in either the prompt or the output.
    is_repeated = prompt_mask | output_mask

    # Positive logits are divided by the penalty, non-positive ones multiplied,
    # so (for penalty > 1) repeated tokens always become less likely.
    # NOTE(review): the second tl.where reads the already-updated logits; this
    # does not double-apply only because a positive logit divided by a positive
    # penalty stays positive — assumes penalties > 0, TODO confirm with callers.
    logits = tl.where(is_repeated & (logits > 0), logits / penalty, logits)
    logits = tl.where(is_repeated & (logits <= 0), logits * penalty, logits)

    tl.store(logits_ptr + logits_idx, logits, mask=valid_vocab)
def apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties):
    """Apply repetition penalties to ``logits`` in place.

    Launches ``_repetition_penalty_kernel`` over a (num_seqs, vocab tiles)
    grid. Repeated tokens (marked in either mask) have positive logits divided
    by, and non-positive logits multiplied by, their sequence's penalty.

    Args:
        logits: contiguous 2D float tensor of shape [num_seqs, vocab_size];
            modified in place.
        prompt_mask: contiguous bool tensor, same shape as ``logits``.
        output_mask: contiguous bool tensor, same shape as ``logits``.
        repetition_penalties: contiguous 1D tensor of length ``num_seqs``.

    Returns:
        None — the penalties are written into ``logits``.

    Raises:
        ValueError: on contiguity/shape violations.
        TypeError: when a mask is not a bool tensor.
    """
    # Validate with explicit raises rather than `assert`: asserts are stripped
    # under `python -O`, which would let malformed inputs reach the kernel.
    if not logits.is_contiguous():
        raise ValueError("logits must be contiguous")
    if not prompt_mask.is_contiguous():
        raise ValueError("prompt_mask must be contiguous")
    if prompt_mask.dtype != torch.bool:
        raise TypeError("prompt_mask must be a bool tensor")
    if not output_mask.is_contiguous():
        raise ValueError("output_mask must be contiguous")
    if output_mask.dtype != torch.bool:
        raise TypeError("output_mask must be a bool tensor")
    if not repetition_penalties.is_contiguous():
        raise ValueError("repetition_penalties must be contiguous")
    if logits.dim() != 2:
        raise ValueError(f"logits must be 2D, got {logits.dim()}D")
    if not (logits.shape == prompt_mask.shape == output_mask.shape):
        raise ValueError("shape mismatch between logits and masks")
    if repetition_penalties.dim() != 1 or repetition_penalties.numel() != logits.shape[0]:
        raise ValueError("repetition_penalties must be 1D with length equal to num_seqs")

    num_seqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024

    # One program per (sequence, BLOCK_SIZE-wide vocab tile).
    grid = (
        num_seqs,
        triton.cdiv(vocab_size, BLOCK_SIZE),
    )
    _repetition_penalty_kernel[grid](
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
        num_seqs,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # In-place operation: follow the stdlib mutator convention and return None.
    return None