Coverage for src/flag_gems/ops/reflection_pad2d.py: 46%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def reflection_pad2d_kernel(
    in_ptr,
    out_ptr,
    B,
    H_in,
    W_in,
    pad_left,
    pad_top,
    H_out,
    W_out,
    BLOCK_HW: tl.constexpr,
):
    """Fill one (batch, tile) chunk of the reflection-padded output.

    Grid axis 0 indexes the flattened batch, axis 1 indexes tiles of
    BLOCK_HW flattened output pixels. For each output coordinate (h, w)
    the source coordinate is obtained by mirroring across the input
    borders without repeating the edge pixel (PyTorch "reflect" mode).

    Assumes contiguous (B, H_in, W_in) input and (B, H_out, W_out)
    output; the host wrapper validates pad < input size per padded dim.
    """
    pid_b = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Flattened output-pixel indices covered by this program instance.
    offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW)
    # Decode to (h, w) coordinates in the output plane.
    h_idx = offs_n // W_out
    w_idx = offs_n % W_out

    mask = (offs_n < H_out * W_out) & (pid_b < B)

    base_in = pid_b * (H_in * W_in)
    base_out = pid_b * (H_out * W_out)

    # Reflected row index. The reflection of an out-of-range index is
    # periodic with period 2*(H_in - 1); folding |y| into one period and
    # mirroring values >= H_in yields the source row. The period is
    # clamped to at least 1 so a degenerate size-1 dimension cannot
    # trigger a modulo-by-zero (for H_in >= 2 the clamp is a no-op).
    y = h_idx.to(tl.int32) - pad_top
    pH = tl.maximum(2 * (H_in - 1), 1)
    m_h = tl.abs(y) % pH
    ih = tl.where(m_h < H_in, m_h, pH - m_h)

    # Reflected column index (same scheme along the width).
    x = w_idx.to(tl.int32) - pad_left
    pW = tl.maximum(2 * (W_in - 1), 1)
    m_w = tl.abs(x) % pW
    iw = tl.where(m_w < W_in, m_w, pW - m_w)

    # Gather from the reflected input location, scatter to the output.
    in_offs = ih * W_in + iw
    vals = tl.load(in_ptr + base_in + in_offs, mask=mask, other=0)
    tl.store(out_ptr + base_out + offs_n, vals, mask=mask)

58 

59 

@triton.jit
def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr):
    """Element-wise copy of a contiguous (B, H, W) tensor, one tile per program."""
    batch = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    plane = H * W
    offsets = tile * BLOCK_HW + tl.arange(0, BLOCK_HW)
    valid = (offsets < plane) & (batch < B)

    start = batch * plane
    data = tl.load(in_ptr + start + offsets, mask=valid, other=0)
    tl.store(out_ptr + start + offsets, data, mask=valid)

71 

72 

def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Validate arguments and launch the reflection-pad Triton kernels.

    Args:
        input: CUDA tensor with at least 3 dims; the last two are (H, W).
            All leading dims are flattened into one batch dimension.
        padding: sequence of 4 non-negative ints
            (pad_left, pad_right, pad_top, pad_bottom), each strictly
            less than the corresponding input spatial size.
        out: optional pre-allocated output with shape
            (*input.shape[:-2], H + pad_top + pad_bottom,
            W + pad_left + pad_right), matching dtype and device.

    Returns:
        The padded tensor (`out` itself when it was supplied).

    Raises:
        ValueError: on any malformed padding, input, or `out` tensor.
    """
    # Validate padding format
    if not isinstance(padding, (list, tuple)):
        raise ValueError("padding must be a sequence")
    if len(padding) != 4:
        raise ValueError(
            "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)"
        )
    pad_left, pad_right, pad_top, pad_bottom = [int(p) for p in padding]

    # Validate padding values
    if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0:
        raise ValueError("padding values must be >= 0")

    # Validate input
    if input.dim() < 3:
        raise ValueError("input must have at least 3 dimensions")
    if not input.is_cuda:
        raise ValueError("input must be a CUDA tensor")

    x = input.contiguous()
    H_in = int(x.shape[-2])
    W_in = int(x.shape[-1])

    has_padding = bool(pad_left or pad_right or pad_top or pad_bottom)
    # Reflection needs at least 2 pixels along the mirrored dimensions;
    # a pure copy (all pads zero) is valid for any positive spatial size,
    # so only enforce the minimum when some padding is actually requested.
    if has_padding and (H_in < 2 or W_in < 2):
        raise ValueError(
            "input spatial dimensions must be at least 2 for reflection padding when padding > 0"
        )
    # Reflect mode never repeats the edge pixel, so each pad must be
    # strictly smaller than the corresponding input extent.
    if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in:
        raise ValueError(
            "padding values must be less than the input spatial dimensions for reflection padding"
        )

    H_out = H_in + pad_top + pad_bottom
    W_out = W_in + pad_left + pad_right

    leading_shape = x.shape[:-2]
    B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1

    # Handle output tensor
    if out is None:
        result = torch.empty(
            (*leading_shape, H_out, W_out), device=x.device, dtype=x.dtype
        )
        dest = result
    else:
        if not out.is_cuda:
            raise ValueError("out must be a CUDA tensor")
        expected_shape = (*leading_shape, H_out, W_out)
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")
        result = out
        # The kernels assume a contiguous destination. Calling
        # .contiguous() on a non-contiguous tensor would return a *copy*,
        # leaving the caller's `out` unfilled — instead write into a
        # contiguous scratch buffer and copy back at the end.
        if out.is_contiguous():
            dest = out
        else:
            dest = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    BLOCK_HW = 256
    if not has_padding:
        # No padding: just copy
        grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW))
        copy_tensor_kernel[grid](x, dest, B, H_in, W_in, BLOCK_HW=BLOCK_HW)
    else:
        grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW))
        reflection_pad2d_kernel[grid](
            x, dest, B, H_in, W_in, pad_left, pad_top, H_out, W_out, BLOCK_HW=BLOCK_HW
        )

    # Propagate results into a non-contiguous caller-supplied `out`.
    if dest is not result:
        result.copy_(dest)
    return result

148 

149 

def reflection_pad2d(input: torch.Tensor, padding):
    """Allocate and return the reflection-padded copy of `input`.

    Thin functional entry point; all validation and kernel dispatch live
    in `launch_reflection_pad2d`.
    """
    logger.debug("GEMS REFLECTION_PAD2D")
    return launch_reflection_pad2d(input, padding, out=None)

153 

154 

def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Reflection-pad `input` into the caller-supplied `out` tensor.

    Thin out-variant entry point; all validation and kernel dispatch
    live in `launch_reflection_pad2d`.
    """
    logger.debug("GEMS REFLECTION_PAD2D_OUT")
    return launch_reflection_pad2d(input, padding, out=out)