Coverage for src/flag_gems/experimental_ops/replication_pad3d.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def replication_pad3d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    pad_d_before,
    pad_h_before,
    pad_w_before,
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one BLOCK_SIZE chunk of the replication-padded output.

    Each lane owns one output element: its linear index is unravelled into
    (n, c, d, h, w), the spatial coordinates are shifted back by the leading
    pads and clamped into the input extent (which replicates border values
    into the pad region), and the value at that input location is stored to
    the corresponding output position via the supplied strides.
    """
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    lane_ok = idx < n_elements

    # Unravel the linear output index into (n, c, d_o, h_o, w_o),
    # fastest-varying dimension (W) first.
    w_o = idx % W_out
    rem = idx // W_out
    h_o = rem % H_out
    rem = rem // H_out
    d_o = rem % D_out
    rem = rem // D_out
    c = rem % C
    n = rem // C

    # Map each output coordinate to its source coordinate: undo the leading
    # pad, then clamp to [0, dim-1] so edge values are replicated.
    w_i = tl.minimum(tl.maximum(w_o - pad_w_before, 0), W_in - 1)
    h_i = tl.minimum(tl.maximum(h_o - pad_h_before, 0), H_in - 1)
    d_i = tl.minimum(tl.maximum(d_o - pad_d_before, 0), D_in - 1)

    # Strided flat offsets for the gather (input) and scatter (output).
    src = (
        n * in_stride_n
        + c * in_stride_c
        + d_i * in_stride_d
        + h_i * in_stride_h
        + w_i * in_stride_w
    )
    dst = (
        n * out_stride_n
        + c * out_stride_c
        + d_o * out_stride_d
        + h_o * out_stride_h
        + w_o * out_stride_w
    )

    vals = tl.load(in_ptr + src, mask=lane_ok, other=0)
    tl.store(out_ptr + dst, vals, mask=lane_ok)

80 

81 

82def _normalize_3d_pad(padding): 

83 if isinstance(padding, (list, tuple)) and len(padding) == 6: 

84 return tuple(int(x) for x in padding) 

85 raise ValueError( 

86 "padding must be a sequence of 6 integers: (pad_w_left, pad_w_right, pad_h_top, pad_h_bottom, pad_d_front, pad_d_back)" # noqa: E501 

87 ) 

88 

89 

90def _get_5d_shape_and_strides(t: torch.Tensor): 

91 # Returns (N, C, D, H, W), (sN, sC, sD, sH, sW), and a flag indicating if original was 4D 

92 if t.dim() == 5: 

93 N, C, D, H, W = t.shape 

94 sN, sC, sD, sH, sW = t.stride() 

95 was_4d = False 

96 return (N, C, D, H, W), (sN, sC, sD, sH, sW), was_4d 

97 elif t.dim() == 4: 

98 C, D, H, W = t.shape 

99 sC, sD, sH, sW = t.stride() 

100 # Emulate leading N=1 dimension with stride 0 for indexing 

101 N = 1 

102 sN = 0 

103 was_4d = True 

104 return (N, C, D, H, W), (sN, sC, sD, sH, sW), was_4d 

105 else: 

106 raise ValueError("Input must be 4D (C, D, H, W) or 5D (N, C, D, H, W).") 

107 

108 

def _launch_replication_pad3d_kernel(x: torch.Tensor, padding, out: torch.Tensor):
    """Validate arguments and launch the replication-pad Triton kernel.

    Checks device/dtype compatibility, input contiguity, and that `out`'s
    shape equals the input shape grown by the pads, then dispatches one
    kernel instance per BLOCK_SIZE output elements. Returns `out`.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be on CUDA device"
    assert x.dtype == out.dtype, "Input and output dtypes must match"
    assert x.device == out.device, "Input and output must be on the same device"
    assert x.is_contiguous(
        memory_format=torch.contiguous_format
    ), "Input must be contiguous"
    # Output may be non-contiguous: the kernel addresses it via its strides.

    pad_w0, pad_w1, pad_h0, pad_h1, pad_d0, pad_d1 = _normalize_3d_pad(padding)

    (N, C, D_in, H_in, W_in), in_strides, _ = _get_5d_shape_and_strides(x)
    (N_o, C_o, D_out, H_out, W_out), out_strides, _ = _get_5d_shape_and_strides(out)

    # Output must be exactly the input grown by the pads on each side.
    assert N_o == N and C_o == C, "Output N and C must match input"
    expected = (
        D_in + pad_d0 + pad_d1,
        H_in + pad_h0 + pad_h1,
        W_in + pad_w0 + pad_w1,
    )
    assert (D_out, H_out, W_out) == expected, "Output spatial shape mismatch"

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to launch for empty tensors.
        return out

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    replication_pad3d_kernel[grid](
        x,
        out,
        N,
        C,
        D_in,
        H_in,
        W_in,
        D_out,
        H_out,
        W_out,
        pad_d0,
        pad_h0,
        pad_w0,
        # Stride tuples are already in (n, c, d, h, w) order.
        *in_strides,
        *out_strides,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

184 

185 

def replication_pad3d(input: torch.Tensor, padding):
    """Replication-pad a 4D (C, D, H, W) or 5D (N, C, D, H, W) CUDA tensor.

    Allocates a fresh output of the padded shape (preserving the input's
    rank, device, and dtype) and fills it via the Triton kernel. `padding`
    is (pad_w_left, pad_w_right, pad_h_top, pad_h_bottom, pad_d_front,
    pad_d_back).
    """
    pads = _normalize_3d_pad(padding)
    pad_w0, pad_w1, pad_h0, pad_h1, pad_d0, pad_d1 = pads
    (N, C, D_in, H_in, W_in), _, was_4d = _get_5d_shape_and_strides(input)

    spatial = (
        D_in + pad_d0 + pad_d1,
        H_in + pad_h0 + pad_h1,
        W_in + pad_w0 + pad_w1,
    )
    # Keep the output at the caller's original rank.
    out_shape = (C,) + spatial if was_4d else (N, C) + spatial

    out = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    _launch_replication_pad3d_kernel(input, pads, out)
    return out

220 

221 

def replication_pad3d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad `input` into a caller-provided `out` tensor.

    Validates that `out` is 4D or 5D and exactly matches the padded shape
    implied by `input` and `padding`, then fills it via the Triton kernel
    and returns it.
    """
    pads = _normalize_3d_pad(padding)
    pad_w0, pad_w1, pad_h0, pad_h1, pad_d0, pad_d1 = pads
    (N, C, D_in, H_in, W_in), _, _ = _get_5d_shape_and_strides(input)

    spatial = (
        D_in + pad_d0 + pad_d1,
        H_in + pad_h0 + pad_h1,
        W_in + pad_w0 + pad_w1,
    )

    # `out` may be batched or unbatched independently of the input's rank.
    if out.dim() == 5:
        expected_out_shape = (N, C) + spatial
    elif out.dim() == 4:
        expected_out_shape = (C,) + spatial
    else:
        raise ValueError("out tensor must be 4D or 5D")
    assert (
        tuple(out.shape) == expected_out_shape
    ), f"out has incorrect shape, expected {expected_out_shape}, got {tuple(out.shape)}"

    _launch_replication_pad3d_kernel(input, pads, out)
    return out