Coverage for src/flag_gems/ops/conv3d.py: 57%
104 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.ops.conv2d import conv2d_output_size
10from flag_gems.utils import libentry
12logger = logging.getLogger(__name__)
def conv3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute the output extent of a 3D convolution along a single axis.

    The per-axis arithmetic is identical to the 2D case, so this is a thin
    delegation to ``conv2d_output_size``.

    Args:
        in_size: Input extent along this axis.
        kernel_size: Kernel extent along this axis.
        stride: Stride along this axis.
        padding: Padding along this axis.
        dilation: Dilation along this axis.

    Returns:
        Output extent along this axis.
    """
    return conv2d_output_size(in_size, kernel_size, stride, padding, dilation)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv3d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_depth",
        "input_height",
        "input_width",
        "out_c",
        "out_depth",
        "out_height",
        "out_width",
        "weight_depth",
        "weight_height",
        "weight_width",
        "stride_depth",
        "stride_height",
        "stride_width",
        "padding_depth",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv3d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_depth,
    input_height,
    input_width,
    out_c,
    out_depth,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_depth_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_depth_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_depth_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_depth: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_depth: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_depth: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_depth: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_DO_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Direct (implicit-GEMM) 3D convolution forward kernel.

    Grid layout: axis 0 tiles the flattened (n, od, oh, ow) output positions
    in chunks of BLOCK_NI_DO_HO_WO; axis 1 tiles the per-group output
    channels in chunks of BLOCK_CO; axis 2 is one program per group.
    Bias is always added (the host wrapper passes a zero tensor when the
    caller gives no bias).
    """
    pid_ni_do_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flat block offset into (n, od, oh, ow) coordinates by
    # successive division/modulo against the output extents.
    ni_do_ho_wo_offset = pid_ni_do_ho_wo * BLOCK_NI_DO_HO_WO + tl.arange(
        0, BLOCK_NI_DO_HO_WO
    )
    ni_do_ho_offset = ni_do_ho_wo_offset // out_width
    ni_do_offset = ni_do_ho_offset // out_height
    in_n_point_value = ni_do_offset // out_depth
    output_depth_point_value = ni_do_offset % out_depth
    output_height_point_value = ni_do_ho_offset % out_height
    output_width_point_value = ni_do_ho_wo_offset % out_width

    # Advance the base pointers to this program's batch rows and group slice.
    # Viewed with groups split out, input is
    # [in_n, groups, weight_c, input_depth, input_height, input_width] and
    # weight is [groups, out_per_group_c, weight_c, weight_depth,
    # weight_height, weight_width].
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # Accumulate in fp32 regardless of the I/O dtype for numerical stability.
    accum = tl.zeros((BLOCK_NI_DO_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # Single fused loop over (kernel depth, kernel height, kernel width,
    # input-channel tiles); each iteration contributes one BLOCK_CI-wide
    # rank-k update via tl.dot.
    for dhwc in range(weight_depth * weight_height * weight_width * BLOCK_CI_COUNT):
        c = (dhwc % BLOCK_CI_COUNT) * BLOCK_CI
        dhw = dhwc // BLOCK_CI_COUNT
        dh = dhw // weight_width
        d = dh // weight_height
        h = dh % weight_height
        w = dhw % weight_width

        # Map the kernel tap (d, h, w) back to input coordinates; values may
        # fall outside [0, extent) and are masked out below (zero padding).
        input_c_offset = c + tl.arange(0, BLOCK_CI)
        input_depth_offset = (
            d * dilation_depth - padding_depth + stride_depth * output_depth_point_value
        )
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_depth_stride * input_depth_offset)[:, None]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_depth_stride * d)
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Masks guard the batch tail, the channel tail, and out-of-bounds
        # (padded) spatial positions; masked loads yield 0, implementing
        # zero padding implicitly.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_depth_offset)[:, None]
            & (input_depth_offset < input_depth)[:, None]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)

    # Add the per-output-channel bias for this group's channel slice.
    # NOTE(review): pid_group is a scalar program id; the `pid_group[None]`
    # expansion before the [None, :] reshape looks redundant — confirm
    # against the Triton version in use.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias

    # Scatter the accumulated tile back to the output tensor, masking the
    # batch/channel/spatial tails.
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_depth_stride * output_depth_point_value)[:, None]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_depth_point_value < out_depth)[:, None]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)
217# class Conv3d(torch.autograd.Function):
218# @staticmethod
219# def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
220# pass
def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """3D convolution forward (Triton-backed).

    Args:
        input: Input tensor of shape (N, C_in, D, H, W).
        weight: Weight tensor of shape (C_out, C_in // groups, kD, kH, kW).
        bias: Optional 1D bias of shape (C_out,).
        stride: Int or 3-sequence of (depth, height, width) strides.
        padding: Int, 3-sequence, or one of the strings 'same' / 'valid'.
            'same' requires all strides to be 1.
        dilation: Int or 3-sequence of (depth, height, width) dilations.
        groups: Number of convolution groups.

    Returns:
        Output tensor of shape (N, C_out, oD, oH, oW).

    Raises:
        ValueError: If ``padding`` is a string other than 'same'/'valid'.
    """
    logger.debug("GEMS CONV3D")
    # Fix: these messages previously lacked the f prefix, so the {…}
    # placeholders were emitted literally instead of interpolated.
    assert weight.ndim == 5, f"Weights must be 5D, received shape {weight.shape}"
    assert (
        bias is None or bias.ndim == 1
    ), f"Bias must be 1D, received shape {None if bias is None else bias.shape}"

    assert (
        input.shape[1] == groups * weight.shape[1]
    ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
    assert (
        bias is None or weight.shape[0] == bias.shape[0]
    ), f"Incompatible weights ({weight.shape}) and bias ({None if bias is None else bias.shape}) shape"

    # Normalize scalar or sequence arguments to per-axis triples.
    if isinstance(stride, (list, tuple)):
        stride_depth, stride_height, stride_width = stride
    else:
        stride_depth = stride_height = stride_width = stride

    if isinstance(dilation, (list, tuple)):
        dilation_depth, dilation_height, dilation_width = dilation
    else:
        dilation_depth = dilation_height = dilation_width = dilation

    if isinstance(padding, str):
        if padding == "same":
            assert (
                stride_depth == 1 and stride_height == 1 and stride_width == 1
            ), f"Doesn't support any stride values other than 1 in padding = 'same' mode, received stride value {stride}"
            # Renamed from `id`/`ih`/`iw`: `id` shadowed the builtin.
            in_d = input.shape[-3]
            in_h = input.shape[-2]
            in_w = input.shape[-1]
            kernel_size_d = weight.shape[-3]
            kernel_size_h = weight.shape[-2]
            kernel_size_w = weight.shape[-1]
            # Pad so that the output covers at least the input extent; when
            # the total required padding is odd the extra row lands on the
            # leading side and is trimmed after the kernel runs (below).
            padding_depth = math.ceil(
                (
                    stride_depth * (in_d - 1)
                    + 1
                    + dilation_depth * (kernel_size_d - 1)
                    - in_d
                )
                / 2
            )
            padding_height = math.ceil(
                (
                    stride_height * (in_h - 1)
                    + 1
                    + dilation_height * (kernel_size_h - 1)
                    - in_h
                )
                / 2
            )
            padding_width = math.ceil(
                (
                    stride_width * (in_w - 1)
                    + 1
                    + dilation_width * (kernel_size_w - 1)
                    - in_w
                )
                / 2
            )
            # Output extents with the chosen padding; used for the trim.
            od = int(
                (in_d + 2 * padding_depth - dilation_depth * (kernel_size_d - 1) - 1)
                / stride_depth
                + 1
            )
            oh = int(
                (in_h + 2 * padding_height - dilation_height * (kernel_size_h - 1) - 1)
                / stride_height
                + 1
            )
            ow = int(
                (in_w + 2 * padding_width - dilation_width * (kernel_size_w - 1) - 1)
                / stride_width
                + 1
            )
        elif padding == "valid":
            padding_depth = padding_height = padding_width = 0
        else:
            # Fix: message previously lacked the f prefix, misspelled
            # 'valid' as 'valild', and was missing a space after "only".
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    elif isinstance(padding, (list, tuple)):
        padding_depth, padding_height, padding_width = padding
    else:
        padding_depth = padding_height = padding_width = padding

    in_n, _, input_depth, input_height, input_width = input.shape
    out_c, weight_c, weight_depth, weight_height, weight_width = weight.shape
    out_depth = conv3d_output_size(
        input_depth, weight_depth, stride_depth, padding_depth, dilation_depth
    )
    out_height = conv3d_output_size(
        input_height, weight_height, stride_height, padding_height, dilation_height
    )
    out_width = conv3d_output_size(
        input_width, weight_width, stride_width, padding_width, dilation_width
    )

    output_dtype = input.dtype
    output = torch.empty(
        (in_n, out_c, out_depth, out_height, out_width),
        device=input.device,
        dtype=output_dtype,
    )

    # Axis 0 tiles (n, od, oh, ow) jointly, axis 1 tiles the per-group
    # output channels, axis 2 runs one program per group.
    grid = lambda META: (
        triton.cdiv(
            in_n * out_depth * out_height * out_width, META["BLOCK_NI_DO_HO_WO"]
        ),
        triton.cdiv(out_c // groups, META["BLOCK_CO"]),
        groups,
    )

    # The kernel unconditionally adds bias, so substitute zeros when absent.
    if bias is None:
        bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
    else:
        bias_pointer = bias

    conv3d_forward_kernel[grid](
        input,
        weight,
        output,
        bias_pointer,
        in_n,
        input_depth,
        input_height,
        input_width,
        out_c,
        out_depth,
        out_height,
        out_width,
        *input.stride(),
        *weight.stride(),
        *output.stride(),
        weight_c,
        weight_depth,
        weight_height,
        weight_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        dilation_depth,
        dilation_height,
        dilation_width,
        groups=groups,
    )

    # 'same' mode: ceil padding can make the output one element larger than
    # the input per axis; trim the extra leading slice(s) to match.
    if padding == "same":
        output = output[..., (od - in_d) :, (oh - in_h) :, (ow - in_w) :]

    return output