Coverage for src/flag_gems/fused/apply_repetition_penalties.py: 57% (37 statements)
coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
import logging

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)


@triton.jit
def _repetition_penalty_kernel(
    logits_ptr,
    prompt_mask_ptr,
    output_mask_ptr,
    penalties_ptr,
    num_seqs,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (sequence, vocab block): grid axis 0 selects the
    # sequence, axis 1 selects a BLOCK_SIZE-wide slice of the vocabulary.
    seq_idx = tl.program_id(0)
    vocab_offset = tl.program_id(1) * BLOCK_SIZE

    if seq_idx >= num_seqs:
        return

    # Each sequence carries its own scalar penalty.
    penalty = tl.load(penalties_ptr + seq_idx)

    vocab_idx = vocab_offset + tl.arange(0, BLOCK_SIZE)

    valid_vocab = vocab_idx < vocab_size

    # logits and both masks share the same contiguous (num_seqs, vocab_size)
    # layout, so a single flat offset addresses all three.
    logits_idx = seq_idx * vocab_size + vocab_idx
    mask_idx = logits_idx

    prompt_mask = tl.load(prompt_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    output_mask = tl.load(output_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    logits = tl.load(logits_ptr + logits_idx, mask=valid_vocab, other=0.0)

    # A token is penalized if it appeared in the prompt or the generated output.
    is_repeated = prompt_mask | output_mask

    # Standard repetition penalty: divide positive logits and multiply
    # non-positive ones, so a penalty > 1 always pushes the logit down.
    logits = tl.where(is_repeated & (logits > 0), logits / penalty, logits)
    logits = tl.where(is_repeated & (logits <= 0), logits * penalty, logits)

    tl.store(logits_ptr + logits_idx, logits, mask=valid_vocab)
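
For intuition, the kernel's update rule matches the eager PyTorch sketch below. The helper name is hypothetical and not part of this module; unlike the kernel, it returns a new tensor instead of writing in place, which makes it handy as a test reference for the Triton path.

def _reference_repetition_penalties(logits, prompt_mask, output_mask, penalties):
    # Tokens seen in the prompt or in the output so far are penalized.
    is_repeated = prompt_mask | output_mask
    # Broadcast the per-sequence penalty over the vocab dimension.
    p = penalties.unsqueeze(1)
    penalized = torch.where(logits > 0, logits / p, logits * p)
    return torch.where(is_repeated, penalized, logits)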
def apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties):
    logger.debug("GEMS APPLY REPETITION PENALTIES")
    # Flat-offset indexing in the kernel requires contiguous inputs.
    assert logits.is_contiguous(), "logits must be contiguous"
    assert (
        prompt_mask.is_contiguous() and prompt_mask.dtype == torch.bool
    ), "prompt_mask must be a contiguous bool tensor"
    assert (
        output_mask.is_contiguous() and output_mask.dtype == torch.bool
    ), "output_mask must be a contiguous bool tensor"
    assert (
        repetition_penalties.is_contiguous()
    ), "repetition_penalties must be contiguous"
    assert logits.dim() == 2, f"logits must be 2D, got {logits.dim()}D"
    assert (
        logits.shape == prompt_mask.shape == output_mask.shape
    ), "shape mismatch between logits and masks"
    assert (
        repetition_penalties.dim() == 1
        and repetition_penalties.numel() == logits.shape[0]
    ), "repetition_penalties must be 1D with length equal to num_seqs"

    num_seqs, vocab_size = logits.shape

    BLOCK_SIZE = 1024

    # 2D launch grid: one row of programs per sequence, and enough columns
    # to tile the vocabulary in BLOCK_SIZE chunks.
    grid = (
        num_seqs,
        triton.cdiv(vocab_size, BLOCK_SIZE),
    )

    _repetition_penalty_kernel[grid](
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
        num_seqs,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # logits is modified in place; nothing is returned.
    return None
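
A minimal call might look like the following. Shapes and values are illustrative, and the import path assumes this module's location under flag_gems/fused; the package may also re-export the function elsewhere.

import torch
from flag_gems.fused.apply_repetition_penalties import apply_repetition_penalties

num_seqs, vocab_size = 2, 32000
logits = torch.randn(num_seqs, vocab_size, device="cuda")
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
prompt_mask[0, 42] = True  # token 42 appeared in sequence 0's prompt
penalties = torch.full((num_seqs,), 1.2, device="cuda")

apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)  # mutates logits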