Coverage for src/flag_gems/runtime/backend/_mthreads/ops/conv2d.py: 0%
161 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry
# Module-level logger namespaced under the flag_gems runtime backend tree,
# e.g. "flag_gems.runtime.backend._mthreads.ops.conv2d" for this module.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute the output extent of a 2D convolution along one spatial axis.

    Args:
        in_size: Input size.
        kernel_size: Kernel size.
        stride: Stride.
        padding: Padding (applied symmetrically on both sides).
        dilation: Dilation.

    Returns:
        Output size of 2D convolution along this axis.
    """
    # A dilated kernel covers dilation * (kernel_size - 1) + 1 input elements.
    effective_kernel = dilation * (kernel_size - 1) + 1
    span = in_size + 2 * padding - effective_kernel
    return span // stride + 1
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_height",
        "input_width",
        "out_c",
        "out_height",
        "out_width",
        "weight_height",
        "weight_width",
        "stride_height",
        "stride_width",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Grouped 2D convolution forward kernel (implicit GEMM formulation).

    Grid: axis 0 tiles the flattened (n, out_h, out_w) index space in chunks
    of BLOCK_NI_HO_WO; axis 1 tiles per-group output channels in chunks of
    BLOCK_CO; axis 2 is one program per group. Bias is always added — the
    host passes a zero tensor when no bias is given. Accumulation is float32.
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Recover per-element (n, out_h, out_w) indices from the flattened
    # program offset: offset = ((n * out_height) + ho) * out_width + wo.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Advance the input and weight base pointers to this program's group.
    # input and weight are viewed as
    # [in_n, groups, in_c, input_height, input_width] and
    # [groups, out_c, in_c, weight_height, weight_width] respectively.
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # Reduce over (kernel_h, kernel_w, in_channel) in a single flat loop;
    # channels are tiled in BLOCK_CI chunks so each step is a
    # (BLOCK_NI_HO_WO, BLOCK_CI) x (BLOCK_CI, BLOCK_CO) dot.
    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Map output coordinates back to input coordinates for this tap;
        # may be negative or >= input size at the borders (masked below).
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask out-of-batch rows, channel-tile tails, and padding region
        # (out-of-range height/width); masked loads read as zero.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # Add per-output-channel bias (host supplies zeros when bias is None).
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    # Scatter the accumulator tile back to [n, c, ho, wo] output layout.
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_backward_weight"),
    key=[
        "in_n",
        "input_height",
        "input_width",
        "weight_height",
        "weight_width",
        "input_c",
        "stride_height",
        "stride_width",
        "out_height",
        "out_width",
        "out_c",
        "padding_height",
        "padding_width",
    ],
)
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Weight-gradient kernel for grouped 2D convolution.

    Computes dL/dW as a GEMM: rows index the flattened (ci, kh, kw) weight
    positions, columns index per-group output channels, and the reduction
    runs over (n, out_h, out_w). NOTE: `out_c` here is the per-group output
    channel count (the host passes original out_c // groups).

    Logical layouts (as addressed via the given strides):
      out_grad: [n, groups * out_c, ho, wo]
      weight:   [groups * out_c, ci, kh, kw]
      input:    [n, groups * ci, hi, wi]
    Grid: axis 0 tiles ci*kh*kw, axis 1 is one program per group,
    axis 2 tiles per-group output channels.
    """
    # Program ids: 0 -> ci*hk*wk tile, 1 -> group, 2 -> output-channel tile.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flat row offset into (ci, kernel_h, kernel_w) indices.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Advance base pointers to this program's group and channel tile.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups[None] * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    input_pointer += (ci_point_value * input_c_stride[None])[:, None] + (
        pid_groups[None] * input_c_stride * input_c
    )[None, :]

    # Accumulate over every output position (h, w) and batch tile of BLOCK_NO.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # Shape the operands as input [ci*kh*kw, N-tile] and
                # out_grad [N-tile, out_c]; N*ho*wo is the reduce dim.
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)

                # Input coordinate touched by kernel tap (kh, kw) at output
                # position (h, w); may fall in the padding (masked below).
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    # Mask row/column tails before scattering the tile into dL/dW.
    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)
class Conv2d(torch.autograd.Function):
    """Grouped 2D convolution with custom Triton forward/backward kernels.

    Forward launches ``conv2d_forward_kernel``; backward computes the input
    gradient by reusing the forward kernel as a transposed convolution
    (flipped weights, zero-dilated out_grad) and the weight gradient via
    ``conv2d_backward_kernel_weight``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        logger.debug("GEMS_MTHREADS CONV2D")
        # FIX: these messages previously lacked the f-prefix, so the shapes
        # were printed as literal "{weight.shape}" placeholders.
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        # Accept either a scalar or an (h, w) pair for each conv hyperparameter.
        if isinstance(stride, (list, tuple)):
            stride_height, stride_width = stride
        else:
            stride_height = stride_width = stride

        if isinstance(padding, (list, tuple)):
            padding_height, padding_width = padding
        else:
            padding_height = padding_width = padding

        if isinstance(dilation, (list, tuple)):
            dilation_height, dilation_width = dilation
        else:
            dilation_height = dilation_width = dilation

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype
        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=output_dtype,
        )

        # BLOCK_NI_HO_WO tiles the flattened (in_n, out_height, out_width)
        # dims, BLOCK_CO tiles per-group output channels; one grid axis per
        # group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(out_c // groups, META["BLOCK_CO"]),
            groups,
        )

        if bias is None:
            # The kernel adds bias unconditionally, so substitute zeros.
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
        else:
            bias_pointer = bias
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # NOTE: stored out_c is the per-group output channel count.
        ctx.weight_info = (out_c // groups, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Return (grad_input, grad_weight, grad_bias, None x4)."""
        logger.debug("GEMS_MTHREADS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c here equals the original out channels divided by groups.
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # Input grad = "full" correlation of out_grad with the spatially
        # flipped, channel-transposed weight.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        revert_weight = weight.clone()
        revert_weight = torch.flip(revert_weight, dims=[2, 3]).contiguous()

        if groups != 1:
            # Swap (out_c, in_c) per group: [g*co, ci, kh, kw] -> [g*ci, co, kh, kw].
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Zero-insert (stride - 1) rows/cols between out_grad elements so the
        # forward kernel can be reused with stride 1.
        new_out_height = out_grad.shape[2] + (stride_height - 1) * (
            out_grad.shape[2] - 1
        )
        new_out_width = out_grad.shape[3] + (stride_width - 1) * (out_grad.shape[3] - 1)

        if stride_height > 1 or stride_width > 1:
            new_out = torch.zeros(
                out_grad.shape[0],
                out_grad.shape[1],
                new_out_height,
                new_out_width,
                device=device,
                dtype=out_grad.dtype,
            )
            # One strided slice assignment replaces the original per-element
            # Python double loop — same result, no per-element kernel launches.
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            new_out = out_grad

        # NOTE(review): input grad is returned in float32 regardless of
        # input.dtype — presumably for accumulation accuracy; confirm the
        # dtype mismatch is acceptable to callers.
        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=torch.float32,
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(int(weight_c), META["BLOCK_CO"]),
            groups,
        )
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight.dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
        )
        if bias is not None:
            # NOTE(review): bias grad is computed and returned in float64 —
            # confirm autograd accepts the dtype for lower-precision bias.
            bias_grad = out_grad.to(torch.float64).sum(dim=(0, 2, 3))
        else:
            bias_grad = None
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )
# TODO: test SymInt[2] of stride or padding
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Functional entry point delegating to the autograd-aware Conv2d op."""
    args = (input, weight, bias, stride, padding, dilation, groups)
    return Conv2d.apply(*args)