Coverage for src/flag_gems/fused/apply_repetition_penalties.py: 53%

34 statements  


import torch
import triton
import triton.language as tl


@triton.jit
def _repetition_penalty_kernel(
    logits_ptr,
    prompt_mask_ptr,
    output_mask_ptr,
    penalties_ptr,
    num_seqs,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (sequence, vocab tile): axis 0 indexes the sequence,
    # axis 1 indexes a BLOCK_SIZE-wide tile of the vocabulary.
    seq_idx = tl.program_id(0)
    vocab_offset = tl.program_id(1) * BLOCK_SIZE

    if seq_idx >= num_seqs:
        return

    # Per-sequence repetition penalty.
    penalty = tl.load(penalties_ptr + seq_idx)

    vocab_idx = vocab_offset + tl.arange(0, BLOCK_SIZE)

    # Guard the tail tile when vocab_size is not a multiple of BLOCK_SIZE.
    valid_vocab = vocab_idx < vocab_size

    # Flat offsets into the contiguous (num_seqs, vocab_size) tensors;
    # the masks share the logits layout.
    logits_idx = seq_idx * vocab_size + vocab_idx
    mask_idx = logits_idx

    prompt_mask = tl.load(prompt_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    output_mask = tl.load(output_mask_ptr + mask_idx, mask=valid_vocab, other=False)
    logits = tl.load(logits_ptr + logits_idx, mask=valid_vocab, other=0.0)

    # A token is penalized if it appeared in either the prompt or the output.
    is_repeated = prompt_mask | output_mask

    # Standard repetition penalty: for penalty > 1, positive logits shrink
    # (divide) and non-positive logits are pushed further down (multiply).
    logits = tl.where(is_repeated & (logits > 0), logits / penalty, logits)
    logits = tl.where(is_repeated & (logits <= 0), logits * penalty, logits)

    tl.store(logits_ptr + logits_idx, logits, mask=valid_vocab)

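
# Not part of the original file: a minimal eager-mode reference sketch of the
# same penalty rule, handy for checking the kernel's output in tests. The name
# _apply_repetition_penalties_torch is hypothetical. Unlike the kernel, this
# returns a new tensor instead of mutating logits in place.
def _apply_repetition_penalties_torch(logits, prompt_mask, output_mask, repetition_penalties):
    is_repeated = prompt_mask | output_mask
    penalty = repetition_penalties.unsqueeze(1)  # (num_seqs, 1), broadcasts over vocab
    penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
    return torch.where(is_repeated, penalized, logits)
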

def apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties):
    """Apply repetition penalties to logits in place.

    Tokens marked in prompt_mask or output_mask have their logits scaled by
    the per-sequence penalty: positive logits are divided by it, non-positive
    logits are multiplied by it.
    """
    assert logits.is_contiguous(), "logits must be contiguous"
    assert (
        prompt_mask.is_contiguous() and prompt_mask.dtype == torch.bool
    ), "prompt_mask must be a contiguous bool tensor"
    assert (
        output_mask.is_contiguous() and output_mask.dtype == torch.bool
    ), "output_mask must be a contiguous bool tensor"
    assert (
        repetition_penalties.is_contiguous()
    ), "repetition_penalties must be contiguous"
    assert logits.dim() == 2, f"logits must be 2D, got {logits.dim()}D"
    assert (
        logits.shape == prompt_mask.shape == output_mask.shape
    ), "shape mismatch between logits and masks"
    assert (
        repetition_penalties.dim() == 1
        and repetition_penalties.numel() == logits.shape[0]
    ), "repetition_penalties must be 1D with length equal to num_seqs"

    num_seqs, vocab_size = logits.shape

    BLOCK_SIZE = 1024

    # One program per sequence per vocabulary tile.
    grid = (
        num_seqs,
        triton.cdiv(vocab_size, BLOCK_SIZE),
    )

    _repetition_penalty_kernel[grid](
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
        num_seqs,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # The kernel mutates logits in place; nothing is returned.
    return None
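

# Not part of the original file: a minimal usage sketch, assuming a CUDA
# device and float32 logits; shapes, masks, and penalty values are
# illustrative only.
if __name__ == "__main__":
    num_seqs, vocab_size = 4, 50257
    logits = torch.randn(num_seqs, vocab_size, device="cuda")
    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool, device="cuda")
    output_mask = torch.zeros_like(prompt_mask)
    prompt_mask[:, :16] = True  # pretend the first 16 token ids appeared in the prompt
    penalties = torch.full((num_seqs,), 1.2, device="cuda")

    expected = _apply_repetition_penalties_torch(logits, prompt_mask, output_mask, penalties)
    apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)  # in place
    torch.testing.assert_close(logits, expected)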