Coverage for src/flag_gems/ops/conv2d.py: 13%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute the spatial output extent of a 2D convolution along one axis.

    Uses the standard PyTorch convolution arithmetic:
        out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1

    Args:
        in_size: Input size along this axis.
        kernel_size: Kernel size along this axis.
        stride: Stride along this axis.
        padding: Zero-padding applied on each side.
        dilation: Spacing between kernel elements.

    Returns:
        Output size along this axis.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded_size = in_size + 2 * padding
    return (padded_size - effective_kernel) // stride + 1
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_height",
        "input_width",
        "out_c",
        "out_height",
        "out_width",
        "weight_height",
        "weight_width",
        "stride_height",
        "stride_width",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Grouped 2D convolution forward kernel (implicit-GEMM style).

    Grid layout: axis 0 tiles the flattened (batch, out_height, out_width)
    index, axis 1 tiles the per-group output channels, axis 2 is the group
    index.  Each program accumulates one (BLOCK_NI_HO_WO, BLOCK_CO) output
    tile via tl.dot over all kernel positions and input-channel tiles, then
    adds the bias and stores the tile.  Bias is loaded unconditionally, so
    the host must pass a zeros tensor when there is no bias.
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flattened block offset into (n, out_h, out_w) indices.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Advance base pointers to this program's batch rows and group slice.
    # input and weight are logically viewed as
    # [in_n, groups, in_c, input_height, input_width] and
    # [groups, out_c, in_c, weight_height, weight_width].
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # Accumulate in float32 for accuracy regardless of the storage dtype.
    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # Single fused loop over (kernel_h, kernel_w, input-channel tile).
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Map each output coordinate back to its (possibly padded) input
        # coordinate for this kernel tap.
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask batch-tail rows, channel-tail columns and padded border
        # pixels; masked loads yield 0, which contributes nothing to tl.dot.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # Add the per-output-channel bias for this group's channel tile.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    # Scatter the tile to output[n, group * out_per_group_c + co, oh, ow].
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_backward_weight"),
    key=[
        "in_n",
        "input_height",
        "input_width",
        "weight_height",
        "weight_width",
        "input_c",
        "stride_height",
        "stride_width",
        "out_height",
        "out_width",
        "out_c",
        "padding_height",
        "padding_width",
    ],
)
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Weight-gradient kernel for the grouped 2D convolution.

    Computes dW[ci, kh, kw, co] by accumulating input * out_grad over the
    reduction dimension (batch, out_height, out_width) with tl.dot, then
    stores the (BLOCK_CI_HK_WK, BLOCK_CO) tile into the weight-grad buffer
    (``weight_pointer``).  Here ``out_c`` and ``in_c`` are *per-group*
    channel counts; the group is selected by grid axis 1.
    """
    # load out_grad n (groups out_c) ho wo
    # load weight (groups out_c) ci h w
    # load input n (groups ci) hi wi

    # init pid and offset: 0 for ci*hk*wk, 1 for groups, 2 for co.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flattened offset into (ci, kernel_h, kernel_w) indices.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Calculate the initial pointer offsets for each tensor: shift to this
    # group's channel slab and this program's output-channel tile.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups[None] * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    input_pointer += (ci_point_value * input_c_stride[None])[:, None] + (
        pid_groups[None] * input_c_stride * input_c
    )[None, :]

    # Reduce over every output position (h, w) and batch tile n.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # Shape input pointers to [cin*kh*kw, BLOCK_NO] and out_grad
                # pointers to [BLOCK_NO, out_c]; N*hout*wout is the reduce dim.
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)

                # Map the output position back to the input coordinate this
                # kernel tap read in the forward pass.
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                # Mask batch tails, channel tails and padded border pixels.
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    # Store the finished gradient tile, masking out-of-range rows/columns.
    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)
class Conv2d(torch.autograd.Function):
    """Grouped 2D convolution backed by the Triton kernels in this module.

    forward launches ``conv2d_forward_kernel``.  backward computes
    grad_input by running the same forward kernel on a zero-dilated
    ``out_grad`` against spatially flipped (and per-group transposed)
    weights with stride 1, grad_weight with
    ``conv2d_backward_kernel_weight``, and grad_bias as a plain reduction.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Run the convolution and stash everything backward needs.

        Args:
            input: (N, C_in, H, W) tensor; C_in must equal groups * weight.shape[1].
            weight: (C_out, C_in // groups, kH, kW) tensor.
            bias: optional (C_out,) tensor.
            stride, padding, dilation: int or (height, width) pair each.
            groups: number of convolution groups.

        Returns:
            (N, C_out, out_H, out_W) tensor of the same dtype as ``input``.
        """
        logger.debug("GEMS CONV2D")
        # Bug fix: these messages were missing the f-prefix, so the {...}
        # placeholders were never interpolated into the assertion text.
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        # Normalize scalar-or-pair hyperparameters to (height, width).
        if isinstance(stride, (list, tuple)):
            stride_height, stride_width = stride
        else:
            stride_height = stride_width = stride

        if isinstance(padding, (list, tuple)):
            padding_height, padding_width = padding
        else:
            padding_height = padding_width = padding

        if isinstance(dilation, (list, tuple)):
            dilation_height, dilation_width = dilation
        else:
            dilation_height = dilation_width = dilation

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype
        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=output_dtype,
        )

        # BLOCK_NI_HO_WO along the in_n, out_height, and out_width dimensions,
        # BLOCK_CO along the out_c,
        # one group per grid axis 2
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(int(out_c // groups), META["BLOCK_CO"]),
            groups,
        )

        # The kernel adds bias unconditionally, so pass zeros when bias is None.
        if bias is None:
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
        else:
            bias_pointer = bias
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # Note: out_c is stored *per group* here; backward relies on that.
        ctx.weight_info = (int(out_c / groups), weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Compute gradients w.r.t. input, weight and bias.

        grad_input is a transposed convolution expressed as a stride-1
        forward convolution of the zero-dilated out_grad with flipped
        weights; grad_weight reduces input * out_grad over
        (N, out_H, out_W); grad_bias sums out_grad over (N, H, W).
        """
        logger.debug("GEMS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # (out_c here equals the original cout divided by groups)
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # "Full"-correlation padding for the transposed convolution.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        # Flip kernels spatially and swap the (out_c, in_c) roles per group.
        revert_weight = weight.clone()
        revert_weight = torch.flip(revert_weight, dims=[2, 3]).contiguous()

        if groups != 1:
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Zero-dilate out_grad: insert (stride - 1) zeros between adjacent
        # gradient pixels so a stride-1 convolution undoes the stride.
        new_out_height = out_grad.shape[2] + (stride_height - 1) * (
            out_grad.shape[2] - 1
        )
        new_out_width = out_grad.shape[3] + (stride_width - 1) * (out_grad.shape[3] - 1)

        new_out = torch.zeros(
            out_grad.shape[0],
            out_grad.shape[1],
            new_out_height,
            new_out_width,
            device=device,
            dtype=out_grad.dtype,
        )

        # copy out_grad to new_out
        if stride_height > 1 or stride_width > 1:
            for i in range(out_grad.shape[2]):
                for j in range(out_grad.shape[3]):
                    new_out[:, :, i * (stride_height), j * (stride_width)] = out_grad[
                        :, :, i, j
                    ]
        else:
            new_out = out_grad

        # NOTE(review): grad_input is materialized in float32 even when the
        # forward dtype was lower precision — confirm callers expect that.
        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=torch.float32,
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(int(weight_c), META["BLOCK_CO"]),
            groups,
        )
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight.dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
        )
        # grad_bias: reduce out_grad over batch and spatial dims.
        if bias is not None:
            bias_grad = out_grad.sum(dim=(0, 2, 3))
        else:
            bias_grad = None
        # One gradient (or None) per forward argument.
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )
# todo test SymInt[2] of stride or padding
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Grouped 2D convolution, mirroring ``torch.nn.functional.conv2d``.

    ``padding`` may be an int, a (height, width) pair, or one of the
    strings "same" / "valid".  "same" requires ``stride == 1`` and pads so
    the output matches the input's spatial size; "valid" means no padding.

    Raises:
        ValueError: if ``padding`` is a string other than "valid" or "same".
    """
    if isinstance(padding, str):
        if padding == "same":
            # Bug fix: the message was missing the f-prefix, so {stride}
            # was never interpolated.
            assert (
                stride == 1
            ), f"Doesn't support any stride values other than 1 \
                in padding = 'same' mode, received stride value {stride}"
            ih = input.shape[-2]
            iw = input.shape[-1]
            kernel_size_h = weight.shape[-2]
            kernel_size_w = weight.shape[-1]
            # Smallest symmetric padding keeping the output at least as
            # large as the input on each axis.
            padding_h = int(
                math.ceil(
                    (stride * (ih - 1) + 1 + dilation * (kernel_size_h - 1) - ih) / 2
                )
            )
            padding_w = int(
                math.ceil(
                    (stride * (iw - 1) + 1 + dilation * (kernel_size_w - 1) - iw) / 2
                )
            )
            oh = int(
                (ih + 2 * padding_h - dilation * (kernel_size_h - 1) - 1) / stride + 1
            )
            ow = int(
                (iw + 2 * padding_w - dilation * (kernel_size_w - 1) - 1) / stride + 1
            )
            # A single symmetric padding is applied; trim any extra leading
            # rows/cols so the result is exactly (ih, iw).
            padding = max(padding_h, padding_w)
            return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)[
                ..., (oh - ih) :, (ow - iw) :
            ]
        elif padding == "valid":
            return Conv2d.apply(input, weight, bias, stride, 0, dilation, groups)
        else:
            # Bug fix: message typo — "only'valild'" -> "only 'valid'".
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    else:
        return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)