Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv3d.py: 0%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.utils import libentry
10from .conv2d import conv2d_output_size
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def conv3d_output_size(
    in_size: int, kernel_size: int, stride: int, padding: int, dilation: int
) -> int:
    """Compute one spatial output dimension of a 3D convolution.

    The per-dimension output-size arithmetic is identical to the 2D case,
    so this simply delegates to ``conv2d_output_size``.

    Args:
        in_size: Input size along this dimension.
        kernel_size: Kernel extent along this dimension.
        stride: Stride along this dimension.
        padding: Zero-padding along this dimension.
        dilation: Dilation along this dimension.

    Returns:
        Output size along this dimension.
    """
    return conv2d_output_size(in_size, kernel_size, stride, padding, dilation)
@libentry()
# Autotuning is currently disabled; the caller passes fixed block sizes.
# @triton.autotune(
#     configs=runtime.get_tuned_config("conv3d_forward"),
#     key=[
#         "in_n",
#         "weight_c",
#         "input_depth",
#         "input_height",
#         "input_width",
#         "out_c",
#         "out_depth",
#         "out_height",
#         "out_width",
#         "weight_depth",
#         "weight_height",
#         "weight_width",
#         "stride_depth",
#         "stride_height",
#         "stride_width",
#         "padding_depth",
#         "padding_height",
#         "padding_width",
#         "groups",
#     ],
# )
@triton.jit
def conv3d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_depth,
    input_height,
    input_width,
    out_c,
    out_depth,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_depth_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_depth_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_depth_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_depth: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_depth: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_depth: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_depth: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_DO_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    # Grid layout: axis 0 tiles the flattened (n, out_d, out_h, out_w)
    # output positions, axis 1 tiles output channels within one group,
    # axis 2 enumerates convolution groups.
    pid_ni_do_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flattened offsets of this program's BLOCK_NI_DO_HO_WO
    # rows back into (n, out_d, out_h, out_w) coordinates.
    ni_do_ho_wo_offset = pid_ni_do_ho_wo * BLOCK_NI_DO_HO_WO + tl.arange(
        0, BLOCK_NI_DO_HO_WO
    )
    ni_do_ho_offset = ni_do_ho_wo_offset // out_width
    ni_do_offset = ni_do_ho_offset // out_height
    in_n_point_value = ni_do_offset // out_depth
    output_depth_point_value = ni_do_offset % out_depth
    output_height_point_value = ni_do_ho_offset % out_height
    output_width_point_value = ni_do_ho_wo_offset % out_width

    # Advance the base pointers to this batch element / group.
    # NOTE(review): the addressing treats input as
    # [in_n, groups, in_c, D, H, W] and weight as
    # [groups, out_per_group_c, in_c, KD, KH, KW], with weight_c being the
    # per-group input-channel count — confirm against the host wrapper.
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # Accumulate in fp32 regardless of the input dtype.
    accum = tl.zeros((BLOCK_NI_DO_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # One flattened loop over (kd, kh, kw, input-channel block); the
    # channel block index varies fastest.
    for dhwc in range(weight_depth * weight_height * weight_width * BLOCK_CI_COUNT):
        c = (dhwc % BLOCK_CI_COUNT) * BLOCK_CI
        dhw = dhwc // BLOCK_CI_COUNT
        dh = dhw // weight_width
        d = dh // weight_height
        h = dh % weight_height
        w = dhw % weight_width

        # Map each output coordinate back to its (possibly out-of-bounds,
        # i.e. padding) input coordinate for this kernel tap.
        input_c_offset = c + tl.arange(0, BLOCK_CI)
        input_depth_offset = (
            d * dilation_depth - padding_depth + stride_depth * output_depth_point_value
        )
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        # 2D tiles: rows index output positions, columns index channels.
        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_depth_stride * input_depth_offset)[:, None]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_depth_stride * d)
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Masked-out lanes (padding region, channel tail, batch tail) load
        # as zero and therefore do not contribute to the dot product.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_depth_offset)[:, None]
            & (input_depth_offset < input_depth)[:, None]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        # allow_tf32=False keeps the dot product in full fp32 precision.
        accum += tl.dot(input_block, weight_block, allow_tf32=False)

    # Add bias. The host wrapper always supplies a bias tensor (zeros when
    # the user passed bias=None), so no separate no-bias path is needed.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias

    # Scatter the accumulated tile back to the output tensor, masking the
    # ragged tails along every dimension.
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_depth_stride * output_depth_point_value)[:, None]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_depth_point_value < out_depth)[:, None]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)
217# class Conv3d(torch.autograd.Function):
218# @staticmethod
219# def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
220# pass
def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a 3D convolution via the Triton forward kernel.

    Args:
        input: Input tensor of shape (N, C_in, D, H, W).
        weight: Weight tensor of shape (C_out, C_in // groups, KD, KH, KW).
        bias: Optional 1D bias tensor of shape (C_out,).
        stride: Int or (depth, height, width) sequence.
        padding: Int or (depth, height, width) sequence.
        dilation: Int or (depth, height, width) sequence.
        groups: Number of convolution groups.

    Returns:
        Output tensor of shape (N, C_out, D_out, H_out, W_out) with the same
        dtype as ``input``.
    """
    logger.debug("GEMS CONV3D")
    # BUGFIX: these messages were plain strings containing "{...}"
    # placeholders; the missing f-prefix meant failures printed the literal
    # braces instead of the offending shapes.
    assert weight.ndim == 5, f"Weights must be 5D, received shape {weight.shape}"
    assert (
        bias is None or bias.ndim == 1
    ), f"Bias must be 1D, received shape {bias.shape}"

    assert (
        input.shape[1] == groups * weight.shape[1]
    ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
    assert (
        bias is None or weight.shape[0] == bias.shape[0]
    ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

    # Normalize each hyper-parameter to per-(depth, height, width) values.
    if isinstance(stride, (list, tuple)):
        stride_depth, stride_height, stride_width = stride
    else:
        stride_depth = stride_height = stride_width = stride

    if isinstance(padding, (list, tuple)):
        padding_depth, padding_height, padding_width = padding
    else:
        padding_depth = padding_height = padding_width = padding

    if isinstance(dilation, (list, tuple)):
        dilation_depth, dilation_height, dilation_width = dilation
    else:
        dilation_depth = dilation_height = dilation_width = dilation

    in_n, _, input_depth, input_height, input_width = input.shape
    out_c, weight_c, weight_depth, weight_height, weight_width = weight.shape
    out_depth = conv3d_output_size(
        input_depth, weight_depth, stride_depth, padding_depth, dilation_depth
    )
    out_height = conv3d_output_size(
        input_height, weight_height, stride_height, padding_height, dilation_height
    )
    out_width = conv3d_output_size(
        input_width, weight_width, stride_width, padding_width, dilation_width
    )

    output_dtype = input.dtype

    # For float16 inputs, promote to float32 for computation to prevent
    # overflow. Same issue as conv2d: float16 max value ~65504; 3D
    # convolution with large channels/kernels easily overflows, causing NaN
    # propagation.
    use_fp32_compute = input.dtype == torch.float16
    if use_fp32_compute:
        input = input.to(torch.float32)
        weight = weight.to(torch.float32)
        if bias is not None:
            bias = bias.to(torch.float32)
        compute_dtype = torch.float32
    else:
        compute_dtype = output_dtype

    output = torch.empty(
        (in_n, out_c, out_depth, out_height, out_width),
        device=input.device,
        dtype=compute_dtype,
    )

    # Grid: BLOCK_NI_DO_HO_WO tiles the flattened (n, out_d, out_h, out_w)
    # positions, BLOCK_CO tiles output channels within a group, and the
    # third axis is one program per group.
    grid = lambda META: (
        triton.cdiv(
            in_n * out_depth * out_height * out_width, META["BLOCK_NI_DO_HO_WO"]
        ),
        triton.cdiv(out_c // groups, META["BLOCK_CO"]),
        groups,
    )

    # The kernel unconditionally adds bias, so substitute zeros when the
    # caller gave none.
    if bias is None:
        bias_pointer = torch.zeros(out_c, device=input.device, dtype=torch.float)
    else:
        bias_pointer = bias.to(torch.float)

    conv3d_forward_kernel[grid](
        input,
        weight,
        output,
        bias_pointer,
        in_n,
        input_depth,
        input_height,
        input_width,
        out_c,
        out_depth,
        out_height,
        out_width,
        *input.stride(),
        *weight.stride(),
        *output.stride(),
        weight_c,
        weight_depth,
        weight_height,
        weight_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        dilation_depth,
        dilation_height,
        dilation_width,
        groups=groups,
        BLOCK_NI_DO_HO_WO=32,
        BLOCK_CI=32,
        BLOCK_CO=32,
    )

    # Convert back to the original dtype if we promoted to fp32.
    if use_fp32_compute:
        output = output.to(output_dtype)

    return output