Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv2d.py: 0%
216 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.utils import libentry
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute one spatial output dimension of a 2D convolution.

    Uses the standard formula: the dilated ("effective") kernel spans
    ``dilation * (kernel_size - 1) + 1`` input positions, and the output
    counts how many stride steps fit in the padded input.

    Args:
        in_size: Input size along this dimension.
        kernel_size: Kernel size along this dimension.
        stride: Convolution stride.
        padding: Zero-padding added to both sides.
        dilation: Kernel dilation factor.

    Returns:
        Output size along this dimension.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (in_size + 2 * padding - effective_kernel) // stride + 1
@libentry()
# @triton.autotune(
# configs=runtime.get_tuned_config("conv2d_forward"),
# key=[
# "in_n",
# "weight_c",
# "input_height",
# "input_width",
# "out_c",
# "out_height",
# "out_width",
# "weight_height",
# "weight_width",
# "stride_height",
# "stride_width",
# "padding_height",
# "padding_width",
# "groups",
# ],
# )
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
    USE_MIXED_PRECISION: tl.constexpr,
):
    """
    Mixed-precision forward kernel.
    When USE_MIXED_PRECISION=True: FP16/BF16 I/O + FP32 accumulator

    Implements grouped conv2d as an implicit GEMM: grid axis 0 tiles the
    flattened (n, out_h, out_w) dimension, axis 1 tiles per-group output
    channels, axis 2 enumerates groups. Bias is always added (callers pass
    a zero tensor when there is no bias).
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)
    # Decompose the flat (n * out_h * out_w) tile offset into per-axis indices.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width
    # Load the input and weight pointers. input and weight are of shape
    # [in_n, groups, in_c, input_height, input_width] and [groups, out_c, in_c, weight_height, weight_width]
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    # Advance to this tile's batch rows and this group's channel slab.
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]
    # FP32 accumulator regardless of I/O dtype (mixed-precision safety).
    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # Single fused loop over (kernel_h, kernel_w, in-channel blocks).
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width
        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Map output coordinates back to (possibly out-of-bounds) input rows/cols.
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )
        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )
        # Mask padding region and tail elements; masked tl.load yields 0,
        # which is the correct identity for the accumulation.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]
        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)
        # Mixed precision: convert to FP32 for computation
        if USE_MIXED_PRECISION:
            input_block = input_block.to(tl.float32)
            weight_block = weight_block.to(tl.float32)
        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # NOTE(review): `out_per_group_c` is a scalar, so `out_per_group_c[None, :]`
    # looks suspicious; the likely intent is
    # `(pid_group * out_per_group_c + output_c_offset)[None, :]` — confirm on HW.
    bias_pointer += pid_group * out_per_group_c[None, :] + output_c_offset[None, :]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )
    # Store implicitly downcasts the FP32 accumulator to the output dtype.
    tl.store(output_pointer, accum, mask=output_mask)
@libentry()
# @triton.autotune(
# configs=runtime.get_tuned_config("conv2d_backward_weight"),
# key=[
# "in_n",
# "input_height",
# "input_width",
# "weight_height",
# "weight_width",
# "input_c",
# "stride_height",
# "stride_width",
# "out_height",
# "out_width",
# "out_c",
# "padding_height",
# "padding_width",
# ],
# )
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    groups: tl.constexpr,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Weight-gradient kernel for grouped conv2d.

    Computes grad_weight[ci, kh, kw, co] = sum over (n, oh, ow) of
    input[n, ci, ih, iw] * out_grad[n, co, oh, ow] as a GEMM whose reduce
    dimension is (n * out_h * out_w). Note `weight_pointer` here receives
    the strides of the forward weight tensor but stores into the gradient
    buffer — callers must ensure both share the same (contiguous) layout.
    """
    # load out_grad n (groups out_c) ho wo
    # load weight (groups out_c) ci h w
    # load input n (groups ci) hi wi
    # init pid and offset 0 for ci*hk*wk, 1 for groups, 2 for co.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)
    # Decompose the flat (ci * kernel_h * kernel_w) tile offset.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width
    # Advance base pointers to this group's channel slab and this tile's rows.
    # NOTE(review): the scalar group-offset terms are broadcast via
    # [:, None]/[None, :] on 0-d values — works under Triton's broadcasting,
    # but worth confirming it matches the intended axis.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups * output_c_stride * out_c
    )[:, None]
    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]
    input_pointer += (ci_point_value * input_c_stride)[:, None] + (
        pid_groups * input_c_stride * input_c
    )[None, :]
    # Reduce over every output position and batch block; FP32 accumulator.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)
                # caculate input pointer to [cin*kh*kw, *] out_grad pointer to [*, out_c], N*hout*wout as reduce dim
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]
                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)
                # Map (kernel position, output position) to the input coordinate
                # it multiplied in the forward pass; may be out of bounds (padding).
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )
                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )
                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )
                curr_input = tl.load(curr_input_pointer, mask=input_mask)
                # Mixed precision: always convert to FP32 for FP16/BF16 safety
                # This is a simplified check - in practice, should pass USE_MIXED_PRECISION
                # For now, we detect if it's FP16/BF16 and convert
                if curr_input.dtype != tl.float32:
                    curr_input = curr_input.to(tl.float32)
                if curr_out_grad.dtype != tl.float32:
                    curr_out_grad = curr_out_grad.to(tl.float32)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)
    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    # Store downcasts FP32 accumulator to the gradient buffer's dtype.
    tl.store(weight_pointer, accum, weight_mask)
class Conv2d(torch.autograd.Function):
    """Triton-backed grouped 2D convolution with a hybrid precision strategy.

    Forward and backward both run custom Triton kernels. For FP16/BF16
    inputs the class either upcasts to FP32 at the Python level (small
    problems / BF16) or enables in-kernel mixed precision (large FP16
    problems) — see the rationale comments in ``forward``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Run the conv2d forward kernel.

        Args:
            input: (N, C_in, H, W) tensor with C_in == groups * weight.shape[1].
            weight: (C_out, C_in // groups, kH, kW) tensor.
            bias: optional (C_out,) tensor.
            stride/padding/dilation: int or (h, w) pair.
            groups: number of convolution groups.

        Returns:
            (N, C_out, out_H, out_W) tensor in the input's dtype.
        """
        logger.debug("GEMS CONV2D")
        # Fix: these messages were plain strings, so the {…} placeholders
        # never interpolated; they are f-strings now.
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"
        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        # Normalize hyper-parameters: a scalar applies to both spatial axes.
        if isinstance(stride, (list, tuple)):
            stride_height, stride_width = stride
        else:
            stride_height = stride_width = stride

        if isinstance(padding, (list, tuple)):
            padding_height, padding_width = padding
        else:
            padding_height = padding_width = padding

        if isinstance(dilation, (list, tuple)):
            dilation_height, dilation_width = dilation
        else:
            dilation_height = dilation_width = dilation

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype

        # Hybrid strategy: Python-level FP32 conversion for small cases,
        # kernel-level mixed precision for large cases
        #
        # Hardware constraints (XPU3):
        # - FP16: Supports mixed precision (verified to work)
        # - BF16: Limited support, "unsupported data type" errors in some cases
        #   → Always use Python FP32 conversion for safety
        #
        # Rationale:
        # - Small FP16 cases: Python FP32 matches PyTorch reference exactly
        # - Large FP16 cases: Mixed precision saves 50% bandwidth → 2x speedup
        # - All BF16 cases: Python FP32 for hardware compatibility
        #
        # Threshold: spatial_size > 1024 triggers FP16 mixed precision
        spatial_size = input_height * input_width
        is_large_case = (spatial_size > 1024) and (in_n * out_c > 64)

        # Only enable mixed precision for FP16 large cases
        use_mixed_precision = (input.dtype == torch.float16) and is_large_case
        use_python_fp32 = (
            input.dtype in (torch.float16, torch.bfloat16)
        ) and not use_mixed_precision

        if use_python_fp32:
            # Small cases: convert in Python layer for reference-matching behavior
            input = input.to(torch.float32)
            weight = weight.to(torch.float32)
            if bias is not None:
                bias = bias.to(torch.float32)
            compute_dtype = torch.float32
        else:
            # Large cases or FP32: keep original precision
            compute_dtype = output_dtype

        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=compute_dtype,
        )

        # BLOCK_NI_HO_WO tiles the flattened (n, out_h, out_w) dimension,
        # BLOCK_CO tiles per-group output channels, one grid axis per group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(out_c // groups, META["BLOCK_CO"]),
            groups,
        )

        # The kernel unconditionally adds bias, so pass zeros when absent.
        if bias is None:
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=torch.float)
        else:
            bias_pointer = bias.to(torch.float)
        # Block-size heuristic: 999 for non-square inputs, 32 otherwise.
        # NOTE(review): presumably a hand-tuned value for this backend — confirm.
        if input.shape[2] != input.shape[3]:
            flag = 999
        else:
            flag = 32
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
            BLOCK_NI_HO_WO=flag,
            BLOCK_CI=32,
            BLOCK_CO=32,
            USE_MIXED_PRECISION=use_mixed_precision,
        )

        # Note: when use_python_fp32 is set, the saved tensors are the FP32
        # copies — backward relies on that.
        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # weight_info stores per-group out channels (out_c // groups).
        ctx.weight_info = (out_c // groups, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups
        ctx.use_mixed_precision = use_mixed_precision
        ctx.use_python_fp32 = use_python_fp32
        ctx.output_dtype = output_dtype

        # Convert output back if we used Python-level FP32 conversion
        if use_python_fp32:
            output = output.to(output_dtype)

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Compute grad_input, grad_weight, grad_bias.

        grad_input is computed as a forward convolution of the (stride-
        dilated) out_grad with the flipped, channel-transposed weight;
        grad_weight uses the dedicated weight-gradient kernel.
        """
        logger.debug("GEMS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c here is the per-group output channel count (see forward).
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups
        use_mixed_precision = ctx.use_mixed_precision
        use_python_fp32 = ctx.use_python_fp32
        output_dtype = ctx.output_dtype

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # If forward used Python-level FP32, convert out_grad to match
        if use_python_fp32 and out_grad.dtype in (torch.float16, torch.bfloat16):
            out_grad = out_grad.to(torch.float32)

        # Transposed-convolution padding for the grad_input pass.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        # Flip the kernel spatially and swap in/out channel roles (per group).
        revert_weight = weight.clone()
        revert_weight = torch.flip(revert_weight, dims=[2, 3]).contiguous()

        if groups != 1:
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Calculate new_out dimensions for transposed convolution
        # Must account for output_padding when (input + 2*padding - dilation*(kernel-1) - 1) % stride != 0
        new_out_height = (
            input_height + 2 * padding_height - dilation_height * (weight_height - 1)
        )
        new_out_width = (
            input_width + 2 * padding_width - dilation_width * (weight_width - 1)
        )

        # Dilate out_grad with the forward stride (insert zeros between
        # elements). The strided-slice assignment replaces the previous
        # O(out_h * out_w) Python loop with one equivalent vectorized copy;
        # when stride is 1 no dilation (and no allocation) is needed.
        if stride_height > 1 or stride_width > 1:
            new_out = torch.zeros(
                out_grad.shape[0],
                out_grad.shape[1],
                new_out_height,
                new_out_width,
                device=device,
                dtype=out_grad.dtype,
            )
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            new_out = out_grad

        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=input.dtype,  # Use original dtype for mixed precision
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(weight_c, META["BLOCK_CO"]),
            groups,
        )
        # NOTE(review): 888 appears to be a hand-tuned block size — confirm.
        flag = 888
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        # grad_input = conv(new_out, revert_weight) with stride 1 and
        # the "revert" padding computed above.
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
            BLOCK_NI_HO_WO=flag,
            BLOCK_CI=32,
            BLOCK_CO=32,
            USE_MIXED_PRECISION=use_mixed_precision,
        )

        # For mixed precision: weight_back accumulator must be FP32 to prevent overflow
        # We'll convert back to original dtype at the end
        weight_back_dtype = torch.float32 if use_mixed_precision else weight.dtype

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight_back_dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        # NOTE(review): weight.stride() is passed for the gradient buffer's
        # strides, which assumes `weight` is contiguous and shaped like
        # `weight_back` — holds here because forward saved a 4D weight.
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups,
            BLOCK_NO=32,
            BLOCK_CI_HK_WK=32,
            BLOCK_CO=32,
        )
        # grad_bias is a plain reduction over batch and spatial dims.
        if bias is not None:
            bias_grad = out_grad.sum(dim=(0, 2, 3))
        else:
            bias_grad = None

        # Convert gradients back to original dtype if needed
        if use_python_fp32:
            # Python FP32 path: convert everything back
            if input_back.dtype != output_dtype:
                input_back = input_back.to(output_dtype)
            if weight_back.dtype != output_dtype:
                weight_back = weight_back.to(output_dtype)
            if bias_grad is not None and bias_grad.dtype != output_dtype:
                bias_grad = bias_grad.to(output_dtype)
        elif use_mixed_precision and weight_back.dtype != weight.dtype:
            # Mixed precision path: weight_back was FP32, convert back
            weight_back = weight_back.to(weight.dtype)

        # One gradient slot per forward argument (ctx excluded); the four
        # hyper-parameters are non-differentiable.
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )
# todo test SymInt[2] of stride or padding
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Functional conv2d entry point.

    Accepts either numeric padding (forwarded directly) or the string
    modes 'valid' (no padding) and 'same' (output spatially matches the
    input; stride must be 1).
    """
    # Numeric padding: hand straight through to the autograd Function.
    if not isinstance(padding, str):
        return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)

    if padding == "valid":
        return Conv2d.apply(input, weight, bias, stride, 0, dilation, groups)

    if padding != "same":
        raise ValueError(
            f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
        )

    assert stride == 1, (
        f"Doesn't support any stride values other than 1 in padding = 'same' mode, "
        f"received stride value {stride}"
    )
    import math

    ih, iw = input.shape[-2], input.shape[-1]
    kernel_size_h, kernel_size_w = weight.shape[-2], weight.shape[-1]

    # Per-axis padding that preserves the spatial size at stride 1.
    padding_h = int(
        math.ceil((stride * (ih - 1) + 1 + dilation * (kernel_size_h - 1) - ih) / 2)
    )
    padding_w = int(
        math.ceil((stride * (iw - 1) + 1 + dilation * (kernel_size_w - 1) - iw) / 2)
    )
    oh = int((ih + 2 * padding_h - dilation * (kernel_size_h - 1) - 1) / stride + 1)
    ow = int((iw + 2 * padding_w - dilation * (kernel_size_w - 1) - 1) / stride + 1)

    # The kernel takes a single scalar padding, so over-pad with the larger
    # of the two and crop the surplus leading rows/columns afterwards.
    padding = max(padding_h, padding_w)
    padded = Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)
    return padded[..., (oh - ih) :, (ow - iw) :]