Coverage for src/flag_gems/experimental_ops/replication_pad3d.py: 0% (94 statements)
import torch
import triton
import triton.language as tl


@triton.jit
def replication_pad3d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    pad_d_before,
    pad_h_before,
    pad_w_before,
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    # Unravel linear indices into (n, c, d_out, h_out, w_out)
    w_out = offs % W_out
    tmp = offs // W_out
    h_out = tmp % H_out
    tmp = tmp // H_out
    d_out = tmp % D_out
    tmp = tmp // D_out
    c = tmp % C
    n = tmp // C

    # Compute clamped input indices
    w_in = w_out - pad_w_before
    w_in = tl.maximum(w_in, 0)
    w_in = tl.minimum(w_in, W_in - 1)

    h_in = h_out - pad_h_before
    h_in = tl.maximum(h_in, 0)
    h_in = tl.minimum(h_in, H_in - 1)

    d_in = d_out - pad_d_before
    d_in = tl.maximum(d_in, 0)
    d_in = tl.minimum(d_in, D_in - 1)

    # Compute input and output pointers (strided)
    in_offset = (
        n * in_stride_n
        + c * in_stride_c
        + d_in * in_stride_d
        + h_in * in_stride_h
        + w_in * in_stride_w
    )
    out_offset = (
        n * out_stride_n
        + c * out_stride_c
        + d_out * out_stride_d
        + h_out * out_stride_h
        + w_out * out_stride_w
    )

    vals = tl.load(in_ptr + in_offset, mask=mask, other=0)
    tl.store(out_ptr + out_offset, vals, mask=mask)
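
# Worked example of the clamping above (hypothetical sizes, not from the
# original file): with W_in = 4 and pad_w_before = 2, output columns
# w_out = 0..1 clamp to w_in = 0 (left edge replicated), w_out = 2..5 map to
# w_in = 0..3 (the interior copy), and any larger w_out clamps to w_in = 3
# (right edge replicated). H and D behave the same way, independently.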


def _normalize_3d_pad(padding):
    if isinstance(padding, (list, tuple)) and len(padding) == 6:
        return tuple(int(x) for x in padding)
    raise ValueError(
        "padding must be a sequence of 6 integers: "
        "(pad_w_left, pad_w_right, pad_h_top, pad_h_bottom, pad_d_front, pad_d_back)"
    )
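
# Padding-order note: this is the same last-dimension-first convention that
# torch.nn.functional.pad uses for 3D padding, i.e.
# (w_left, w_right, h_top, h_bottom, d_front, d_back).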


def _get_5d_shape_and_strides(t: torch.Tensor):
    # Returns (N, C, D, H, W), (sN, sC, sD, sH, sW), and a flag indicating
    # whether the original tensor was 4D.
    if t.dim() == 5:
        N, C, D, H, W = t.shape
        sN, sC, sD, sH, sW = t.stride()
        was_4d = False
        return (N, C, D, H, W), (sN, sC, sD, sH, sW), was_4d
    elif t.dim() == 4:
        C, D, H, W = t.shape
        sC, sD, sH, sW = t.stride()
        # Emulate a leading N=1 dimension with stride 0 for indexing
        N = 1
        sN = 0
        was_4d = True
        return (N, C, D, H, W), (sN, sC, sD, sH, sW), was_4d
    else:
        raise ValueError("Input must be 4D (C, D, H, W) or 5D (N, C, D, H, W).")
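
# Illustration of the 4D path (hypothetical shape, not from the original
# file): a (3, 4, 5, 6) input is reported as N = 1 with sN = 0, so the
# kernel's n * in_stride_n term contributes nothing and the same indexing
# covers batched and unbatched layouts alike.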


def _launch_replication_pad3d_kernel(x: torch.Tensor, padding, out: torch.Tensor):
    assert x.is_cuda and out.is_cuda, "Tensors must be on a CUDA device"
    assert x.dtype == out.dtype, "Input and output dtypes must match"
    assert x.device == out.device, "Input and output must be on the same device"
    assert x.is_contiguous(
        memory_format=torch.contiguous_format
    ), "Input must be contiguous"
    # Output can be non-contiguous; we handle that via the strides passed to
    # the kernel.

    (
        pad_w_before,
        pad_w_after,
        pad_h_before,
        pad_h_after,
        pad_d_before,
        pad_d_after,
    ) = _normalize_3d_pad(padding)

    (
        (N, C, D_in, H_in, W_in),
        (in_sN, in_sC, in_sD, in_sH, in_sW),
        x_was_4d,
    ) = _get_5d_shape_and_strides(x)
    (
        (N_o, C_o, D_out, H_out, W_out),
        (out_sN, out_sC, out_sD, out_sH, out_sW),
        out_was_4d,
    ) = _get_5d_shape_and_strides(out)

    # Validate shapes
    assert N_o == N and C_o == C, "Output N and C must match input"
    expected_D_out = D_in + pad_d_before + pad_d_after
    expected_H_out = H_in + pad_h_before + pad_h_after
    expected_W_out = W_in + pad_w_before + pad_w_after
    assert (D_out, H_out, W_out) == (
        expected_D_out,
        expected_H_out,
        expected_W_out,
    ), "Output spatial shape mismatch"

    n_elements = out.numel()
    if n_elements == 0:
        return out

    # One output element per lane; each program covers BLOCK_SIZE elements.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    replication_pad3d_kernel[grid](
        x,
        out,
        N,
        C,
        D_in,
        H_in,
        W_in,
        D_out,
        H_out,
        W_out,
        pad_d_before,
        pad_h_before,
        pad_w_before,
        in_sN,
        in_sC,
        in_sD,
        in_sH,
        in_sW,
        out_sN,
        out_sC,
        out_sD,
        out_sH,
        out_sW,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out


def replication_pad3d(input: torch.Tensor, padding):
    (
        pad_w_before,
        pad_w_after,
        pad_h_before,
        pad_h_after,
        pad_d_before,
        pad_d_after,
    ) = _normalize_3d_pad(padding)
    (N, C, D_in, H_in, W_in), _, was_4d = _get_5d_shape_and_strides(input)

    D_out = D_in + pad_d_before + pad_d_after
    H_out = H_in + pad_h_before + pad_h_after
    W_out = W_in + pad_w_before + pad_w_after

    if was_4d:
        out_shape = (C, D_out, H_out, W_out)
    else:
        out_shape = (N, C, D_out, H_out, W_out)

    out = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    _launch_replication_pad3d_kernel(
        input,
        (
            pad_w_before,
            pad_w_after,
            pad_h_before,
            pad_h_after,
            pad_d_before,
            pad_d_after,
        ),
        out,
    )
    return out


def replication_pad3d_out(input: torch.Tensor, padding, out: torch.Tensor):
    (
        pad_w_before,
        pad_w_after,
        pad_h_before,
        pad_h_after,
        pad_d_before,
        pad_d_after,
    ) = _normalize_3d_pad(padding)
    (N, C, D_in, H_in, W_in), _, was_4d_in = _get_5d_shape_and_strides(input)

    D_out = D_in + pad_d_before + pad_d_after
    H_out = H_in + pad_h_before + pad_h_after
    W_out = W_in + pad_w_before + pad_w_after

    # Validate the provided out shape
    if out.dim() == 5:
        expected_out_shape = (N, C, D_out, H_out, W_out)
    elif out.dim() == 4:
        expected_out_shape = (C, D_out, H_out, W_out)
    else:
        raise ValueError("out tensor must be 4D or 5D")
    assert (
        tuple(out.shape) == expected_out_shape
    ), f"out has incorrect shape, expected {expected_out_shape}, got {tuple(out.shape)}"

    _launch_replication_pad3d_kernel(
        input,
        (
            pad_w_before,
            pad_w_after,
            pad_h_before,
            pad_h_after,
            pad_d_before,
            pad_d_after,
        ),
        out,
    )
    return out
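

# Minimal sanity-check sketch (not part of the original module; assumes a CUDA
# device is available). torch.nn.functional.pad with mode="replicate" serves
# as the reference implementation for replication padding.
if __name__ == "__main__":
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4, 5, 6, device="cuda")
    # (w_left, w_right, h_top, h_bottom, d_front, d_back)
    pad = (1, 2, 0, 1, 2, 0)
    torch.testing.assert_close(
        replication_pad3d(x, pad), F.pad(x, pad, mode="replicate")
    )
    print("replication_pad3d matches F.pad(mode='replicate')")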