Coverage for src/flag_gems/ops/reflection_pad1d.py: 54%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@triton.jit
def reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    """Fill one BLOCK_W-wide tile of the reflection-padded output.

    Launch grid: (B, cdiv(W_out, BLOCK_W)) — axis 0 picks the flattened
    row (product of all leading dims), axis 1 picks the column tile.
    Both buffers are assumed contiguous: input is (B, W_in), output is
    (B, W_out) with W_out = W_in + pad_left + pad_right.

    The host wrapper guarantees W_in >= 2 and 0 <= pad < W_in before
    launching this kernel, so the reflection period p = 2*(W_in-1) is
    strictly positive and the modulo below is well defined.
    """
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    # Guard both the tail tile (offs_w >= W_out) and any over-provisioned rows.
    mask = (offs_w < W_out) & (pid_b < B)

    base_in = pid_b * W_in
    base_out = pid_b * W_out

    # Compute reflected indices.
    # x is the output column expressed in input coordinates; it is negative
    # inside the left pad and >= W_in inside the right pad.
    x = offs_w.to(tl.int32) - pad_left  # shift by left pad
    Wm1 = W_in - 1
    p = 2 * Wm1  # period for reflection; guaranteed > 0 when this kernel is used

    # Reflection about index 0 is symmetric (-1 -> 1, -2 -> 2, ...), so |x|
    # folds the left pad onto the right side; the reflected signal repeats
    # with period p, and within one period any m >= W_in mirrors back to
    # p - m (e.g. W_in = 4: m = 4 -> 2, m = 5 -> 1).
    t = tl.abs(x)
    m = t % p
    iw = tl.where(m < W_in, m, p - m)

    # Masked lanes load 0 and store nothing, so `other=0` is never observable.
    vals = tl.load(in_ptr + base_in + iw, mask=mask, other=0)
    tl.store(out_ptr + base_out + offs_w, vals, mask=mask)

38 

39 

@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Tile-wise copy of a contiguous (B, W) buffer into an identically
    shaped output. Used for the zero-padding fast path, where the padded
    result is simply the input.

    Launch grid: (B, cdiv(W, BLOCK_W)).
    """
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    # Mask off the ragged final tile and any excess rows.
    in_bounds = (cols < W) & (row < B)

    # Input and output share the same layout, so one offset serves both.
    offsets = row * W + cols
    data = tl.load(in_ptr + offsets, mask=in_bounds, other=0)
    tl.store(out_ptr + offsets, data, mask=in_bounds)

51 

52 

53def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None): 

54 if not isinstance(padding, (list, tuple)) or len(padding) != 2: 

55 raise ValueError( 

56 "padding must be a sequence of length 2: (pad_left, pad_right)" 

57 ) 

58 pad_left, pad_right = int(padding[0]), int(padding[1]) 

59 if pad_left < 0 or pad_right < 0: 

60 raise ValueError("padding values must be >= 0") 

61 if input.dim() < 1: 

62 raise ValueError("input must have at least 1 dimension") 

63 

64 x = input.contiguous() 

65 W_in = int(x.shape[-1]) 

66 if W_in <= 0: 

67 raise ValueError("last dimension (width) must be > 0") 

68 

69 W_out = W_in + pad_left + pad_right 

70 leading_shape = x.shape[:-1] 

71 B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1 

72 

73 if out is None: 

74 out = torch.empty((*leading_shape, W_out), device=x.device, dtype=x.dtype) 

75 else: 

76 expected_shape = (*leading_shape, W_out) 

77 if tuple(out.shape) != expected_shape: 

78 raise ValueError( 

79 f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}" 

80 ) 

81 if out.dtype != x.dtype: 

82 raise ValueError( 

83 f"out dtype {out.dtype} does not match input dtype {x.dtype}" 

84 ) 

85 if out.device != x.device: 

86 raise ValueError("out must be on the same device as input") 

87 out = out.contiguous() 

88 

89 # No padding: just copy 

90 if pad_left == 0 and pad_right == 0: 

91 if W_out != W_in: 

92 raise RuntimeError( 

93 "Internal error: W_out should equal W_in when no padding" 

94 ) 

95 grid = (B, triton.cdiv(W_in, 256)) 

96 with torch_device_fn.device(x.device): 

97 _copy_rows_kernel[grid](x, out, B, W_in, BLOCK_W=256) 

98 return out 

99 

100 # Validate reflection padding constraints 

101 if W_in < 2: 

102 raise ValueError( 

103 "input width must be at least 2 for reflection padding when padding > 0" 

104 ) 

105 if pad_left >= W_in or pad_right >= W_in: 

106 raise ValueError( 

107 "padding values must be less than the input width for reflection padding" 

108 ) 

109 

110 grid = (B, triton.cdiv(W_out, 256)) 

111 with torch_device_fn.device(x.device): 

112 reflection_pad1d_kernel[grid](x, out, B, W_in, pad_left, W_out, BLOCK_W=256) 

113 return out 

114 

115 

def reflection_pad1d(input: torch.Tensor, padding):
    """Functional entry point: reflection-pad the last dim of ``input``.

    ``padding`` is (pad_left, pad_right); a freshly allocated tensor is
    returned.
    """
    logger.debug("GEMS REFLECTION_PAD1D")
    result = _launch_reflection_pad1d(input, padding)
    return result

119 

120 

def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Out-variant entry point: reflection-pad ``input`` into ``out``.

    ``padding`` is (pad_left, pad_right); ``out`` must match the padded
    shape, dtype, and device, and is returned.
    """
    logger.debug("GEMS REFLECTION_PAD1D_OUT")
    result = _launch_reflection_pad1d(input, padding, out=out)
    return result