Coverage for src/flag_gems/fused/apply_repetition_penalties.py: 57%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def _repetition_penalty_kernel(
    logits_ptr,
    prompt_mask_ptr,
    output_mask_ptr,
    penalties_ptr,
    num_seqs,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Scale repeated-token logits by a per-sequence penalty, in place.

    Launched on a 2D grid of (num_seqs, cdiv(vocab_size, BLOCK_SIZE)):
    each program updates one BLOCK_SIZE-wide tile of one sequence's row in
    the row-major (num_seqs, vocab_size) logits tensor.
    """
    row = tl.program_id(0)
    # Guard against over-provisioned grids on axis 0.
    if row >= num_seqs:
        return

    cols = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < vocab_size

    # Flat element offsets; both masks share the logits' layout.
    offs = row * vocab_size + cols

    penalty = tl.load(penalties_ptr + row)
    seen_in_prompt = tl.load(prompt_mask_ptr + offs, mask=in_bounds, other=False)
    seen_in_output = tl.load(output_mask_ptr + offs, mask=in_bounds, other=False)
    vals = tl.load(logits_ptr + offs, mask=in_bounds, other=0.0)

    repeated = seen_in_prompt | seen_in_output

    # Penalize tokens already seen in the prompt or generated output:
    # divide positive logits, multiply non-positive ones — both lower the
    # token's probability for penalty > 1. NOTE(review): the second where
    # re-reads the updated vals; this matches the original sign only while
    # penalty > 0, which the penalty semantics presuppose.
    vals = tl.where(repeated & (vals > 0), vals / penalty, vals)
    vals = tl.where(repeated & (vals <= 0), vals * penalty, vals)

    tl.store(logits_ptr + offs, vals, mask=in_bounds)

45 

46 

def apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties):
    """Apply per-sequence repetition penalties to ``logits`` in place.

    Args:
        logits: contiguous 2D tensor of shape (num_seqs, vocab_size),
            modified in place by the Triton kernel.
        prompt_mask: contiguous bool tensor, same shape as ``logits``;
            marks tokens present in each sequence's prompt.
        output_mask: contiguous bool tensor, same shape as ``logits``;
            marks tokens already generated for each sequence.
        repetition_penalties: contiguous 1D tensor of length num_seqs.

    Returns:
        None — the update happens in place on ``logits``.
    """
    logger.debug("GEMS APPLY REPETITION PENALTIES")

    # The kernel indexes with flat row-major offsets, so every tensor must
    # be dense; the masks must also be genuine bool tensors for `|` to be
    # a logical or inside the kernel.
    assert logits.is_contiguous(), "logits must be contiguous"
    assert (
        prompt_mask.is_contiguous() and prompt_mask.dtype == torch.bool
    ), "prompt_mask must be contiguous bool tensor"
    assert (
        output_mask.is_contiguous() and output_mask.dtype == torch.bool
    ), "output_mask must be contiguous bool tensor"
    assert repetition_penalties.is_contiguous(), "repetition_penalties must be contiguous"
    assert logits.dim() == 2, f"logits must be 2D, got {logits.dim()}D"
    assert (
        logits.shape == prompt_mask.shape == output_mask.shape
    ), "shape mismatch between logits and masks"
    assert (
        repetition_penalties.dim() == 1
        and repetition_penalties.numel() == logits.shape[0]
    ), "repetition_penalties must be 1D with length equal to num_seqs"

    n_rows, n_cols = logits.shape
    tile = 1024

    # One program per (sequence, vocab tile) pair.
    launch_grid = (n_rows, triton.cdiv(n_cols, tile))

    _repetition_penalty_kernel[launch_grid](
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
        n_rows,
        n_cols,
        BLOCK_SIZE=tile,
    )
    return None