Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv2d.py: 0%

216 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12 

def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute the output size of a 2D convolution along one spatial dimension.

    Args:
        in_size: Input size along this dimension.
        kernel_size: Kernel size along this dimension.
        stride: Convolution stride.
        padding: Zero-padding applied to each side.
        dilation: Spacing between kernel elements.

    Returns:
        The convolution output size along this dimension.
    """
    # Standard conv arithmetic: floor((padded - effective_kernel) / stride) + 1
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded_size = in_size + 2 * padding
    return (padded_size - effective_kernel) // stride + 1

34 

35 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("conv2d_forward"),
#     key=[
#         "in_n",
#         "weight_c",
#         "input_height",
#         "input_width",
#         "out_c",
#         "out_height",
#         "out_width",
#         "weight_height",
#         "weight_width",
#         "stride_height",
#         "stride_width",
#         "padding_height",
#         "padding_width",
#         "groups",
#     ],
# )
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
    USE_MIXED_PRECISION: tl.constexpr,
):
    """
    Grouped 2D convolution forward kernel (mixed-precision capable).

    Grid layout:
        axis 0 tiles the flattened (batch, out_height, out_width) index space,
        axis 1 tiles the per-group output channels,
        axis 2 is one program per group.

    The accumulator is always FP32; when USE_MIXED_PRECISION=True the loaded
    FP16/BF16 input and weight tiles are upcast to FP32 before tl.dot.
    The bias (always passed — the host substitutes zeros when bias is None)
    is added once after the reduction loop.
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flattened (n, out_h, out_w) offsets handled by this tile.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Advance base pointers to this sample / group. input and weight are
    # logically viewed as [in_n, groups, in_c, input_height, input_width]
    # and [groups, out_c, in_c, weight_height, weight_width] respectively.
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # One fused loop over (kernel_h, kernel_w, input-channel tiles) to keep a
    # single reduction loop body for tl.dot.
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        # Map this (h, w) kernel tap back to input coordinates, accounting
        # for dilation, padding, and stride.
        input_c_offset = c + tl.arange(0, BLOCK_CI)
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask out-of-bounds batch entries, channel remainders, and taps
        # that fall into the zero-padding region.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        # Mixed precision: upcast to FP32 so the dot-product accumulates in
        # full precision even for FP16/BF16 I/O.
        if USE_MIXED_PRECISION:
            input_block = input_block.to(tl.float32)
            weight_block = weight_block.to(tl.float32)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # NOTE(review): out_per_group_c is a scalar, so `out_per_group_c[None, :]`
    # indexes a scalar; the apparent intent is
    # (pid_group * out_per_group_c + output_c_offset)[None, :]. Confirm this
    # indexing is accepted (and equivalent) on the targeted Triton version.
    bias_pointer += pid_group * out_per_group_c[None, :] + output_c_offset[None, :]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)

192 

193 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("conv2d_backward_weight"),
#     key=[
#         "in_n",
#         "input_height",
#         "input_width",
#         "weight_height",
#         "weight_width",
#         "input_c",
#         "stride_height",
#         "stride_width",
#         "out_height",
#         "out_width",
#         "out_c",
#         "padding_height",
#         "padding_width",
#     ],
# )
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    groups: tl.constexpr,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """
    Weight-gradient kernel for grouped conv2d.

    Logical views (here out_c is the PER-GROUP output channel count):
        out_grad: [n, (groups, out_c), out_height, out_width]
        weight:   [(groups, out_c), ci, weight_height, weight_width]
        input:    [n, (groups, ci), input_height, input_width]

    Grid layout: axis 0 tiles the flattened (ci, kernel_h, kernel_w) space,
    axis 1 is one program per group, axis 2 tiles the output channels.
    The reduction runs over (out_height, out_width, batch tiles); the batch
    dimension N*ho*wo is the contraction dimension of tl.dot. Accumulation
    is FP32; FP16/BF16 tiles are upcast before the dot.
    """
    # Program ids: 0 for ci*hk*wk, 1 for groups, 2 for co.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flattened (ci, kernel_h, kernel_w) offsets of this tile.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Compute the base pointers of the three tensors for this group/tile.
    # NOTE(review): `pid_groups * output_c_stride * out_c` is a scalar indexed
    # with [:, None] (and similarly [None, :] for the input below); confirm
    # this scalar-broadcast indexing is accepted by the targeted Triton
    # version — the arithmetic intent is a plain scalar offset.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    input_pointer += (ci_point_value * input_c_stride)[:, None] + (
        pid_groups * input_c_stride * input_c
    )[None, :]

    # Reduce over every output position (h, w) and over batch tiles of
    # size BLOCK_NO.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # Shape input tile to [ci*kh*kw, *] and out_grad tile to
                # [*, out_c]; N*h_out*w_out is the reduction dimension.
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)

                # Map the kernel tap to input coordinates (dilation, padding,
                # stride), as in the forward kernel.
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                # Mask batch remainders, channel remainders, and taps in the
                # zero-padding region.
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask)

                # Mixed precision: always upcast FP16/BF16 tiles to FP32.
                # This is a simplified check - in practice, should pass
                # USE_MIXED_PRECISION; for now we detect non-FP32 and convert.
                if curr_input.dtype != tl.float32:
                    curr_input = curr_input.to(tl.float32)
                if curr_out_grad.dtype != tl.float32:
                    curr_out_grad = curr_out_grad.to(tl.float32)

                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)

354 

355 

class Conv2d(torch.autograd.Function):
    """Autograd-enabled grouped 2D convolution backed by Triton kernels.

    forward() launches ``conv2d_forward_kernel``; backward() computes the
    input gradient by running the same forward kernel as a transposed
    convolution with a flipped/transposed weight, and the weight gradient
    with ``conv2d_backward_kernel_weight``.

    Fixes vs. previous revision:
    - assert messages now actually interpolate shapes (were plain strings
      with literal ``{...}`` placeholders);
    - the Python double loop scattering ``out_grad`` by stride is replaced
      by an equivalent strided-slice assignment;
    - redundant ``weight.clone()`` before ``torch.flip`` removed (flip
      already returns a new tensor).
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Run the conv2d forward pass.

        Args:
            input: (N, C_in, H, W) tensor with C_in == groups * weight.shape[1].
            weight: (C_out, C_in // groups, kH, kW) tensor.
            bias: optional (C_out,) tensor.
            stride, padding, dilation: int or (h, w) pair each.
            groups: number of convolution groups.

        Returns:
            (N, C_out, out_H, out_W) tensor in the input's dtype.
        """
        logger.debug("GEMS CONV2D")
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        # Normalize int-or-pair hyperparameters to (height, width) pairs.
        if isinstance(stride, (list, tuple)):
            stride_height, stride_width = stride
        else:
            stride_height = stride_width = stride

        if isinstance(padding, (list, tuple)):
            padding_height, padding_width = padding
        else:
            padding_height = padding_width = padding

        if isinstance(dilation, (list, tuple)):
            dilation_height, dilation_width = dilation
        else:
            dilation_height = dilation_width = dilation

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype

        # Hybrid strategy: Python-level FP32 conversion for small cases,
        # kernel-level mixed precision for large cases.
        #
        # Hardware constraints (XPU3):
        # - FP16: supports mixed precision (verified to work)
        # - BF16: limited support, "unsupported data type" errors in some
        #   cases -> always use Python FP32 conversion for safety
        #
        # Rationale:
        # - Small FP16 cases: Python FP32 matches PyTorch reference exactly
        # - Large FP16 cases: mixed precision saves 50% bandwidth -> ~2x speedup
        # - All BF16 cases: Python FP32 for hardware compatibility
        #
        # Threshold: spatial_size > 1024 triggers FP16 mixed precision.
        spatial_size = input_height * input_width
        is_large_case = (spatial_size > 1024) and (in_n * out_c > 64)

        # Only enable kernel mixed precision for FP16 large cases.
        use_mixed_precision = (input.dtype == torch.float16) and is_large_case
        use_python_fp32 = (
            input.dtype in (torch.float16, torch.bfloat16)
        ) and not use_mixed_precision

        if use_python_fp32:
            # Small cases: convert in the Python layer for
            # reference-matching behavior.
            input = input.to(torch.float32)
            weight = weight.to(torch.float32)
            if bias is not None:
                bias = bias.to(torch.float32)
            compute_dtype = torch.float32
        else:
            # Large cases or FP32: keep original precision.
            compute_dtype = output_dtype

        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=compute_dtype,
        )

        # Grid: BLOCK_NI_HO_WO tiles the flattened (n, out_h, out_w) space,
        # BLOCK_CO tiles the per-group output channels, one program per group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(int(out_c // groups), META["BLOCK_CO"]),
            groups,
        )

        # The kernel unconditionally adds bias, so substitute zeros when absent.
        if bias is None:
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=torch.float)
        else:
            bias_pointer = bias.to(torch.float)
        # Block-size heuristic: non-square inputs get a larger tile (999),
        # square inputs a small one (32).
        if input.shape[2] != input.shape[3]:
            flag = 999
        else:
            flag = 32
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
            BLOCK_NI_HO_WO=flag,
            BLOCK_CI=32,
            BLOCK_CO=32,
            USE_MIXED_PRECISION=use_mixed_precision,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # weight_info stores the PER-GROUP output channel count.
        ctx.weight_info = (out_c // groups, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups
        ctx.use_mixed_precision = use_mixed_precision
        ctx.use_python_fp32 = use_python_fp32
        ctx.output_dtype = output_dtype

        # Convert back if we used Python-level FP32 conversion.
        if use_python_fp32:
            output = output.to(output_dtype)

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Compute gradients w.r.t. input, weight, and bias.

        Returns a 7-tuple matching forward's (input, weight, bias, stride,
        padding, dilation, groups) signature; the last four entries are None.
        """
        logger.debug("GEMS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c here equals the original C_out divided by groups.
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups
        use_mixed_precision = ctx.use_mixed_precision
        use_python_fp32 = ctx.use_python_fp32
        output_dtype = ctx.output_dtype

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # If forward used Python-level FP32, convert out_grad to match.
        if use_python_fp32 and out_grad.dtype in (torch.float16, torch.bfloat16):
            out_grad = out_grad.to(torch.float32)

        # Input gradient = transposed convolution of out_grad with the
        # spatially flipped, channel-transposed weight.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        # torch.flip already returns a new tensor; no clone needed.
        revert_weight = torch.flip(weight, dims=[2, 3]).contiguous()

        if groups != 1:
            # Swap the (out_c, in_c) axes within each group.
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Dimensions of the stride-dilated out_grad used as the transposed
        # convolution's input.
        new_out_height = (
            input_height + 2 * padding_height - dilation_height * (weight_height - 1)
        )
        new_out_width = (
            input_width + 2 * padding_width - dilation_width * (weight_width - 1)
        )

        new_out = torch.zeros(
            out_grad.shape[0],
            out_grad.shape[1],
            new_out_height,
            new_out_width,
            device=device,
            dtype=out_grad.dtype,
        )

        # Scatter out_grad into new_out at stride intervals. The strided
        # slice has exactly out_grad's spatial shape
        # (ceil(new_out_size / stride) == out_size), so this single
        # assignment replaces the former per-element Python loops.
        if stride_height > 1 or stride_width > 1:
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            new_out = out_grad

        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=input.dtype,  # use original dtype for mixed precision
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(int(weight_c), META["BLOCK_CO"]),
            groups,
        )
        # Fixed block size for the transposed-conv launch.
        flag = 888
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
            BLOCK_NI_HO_WO=flag,
            BLOCK_CI=32,
            BLOCK_CO=32,
            USE_MIXED_PRECISION=use_mixed_precision,
        )

        # For mixed precision the weight-grad accumulator must be FP32 to
        # prevent overflow; converted back to the weight dtype at the end.
        weight_back_dtype = torch.float32 if use_mixed_precision else weight.dtype

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight_back_dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups,
            BLOCK_NO=32,
            BLOCK_CI_HK_WK=32,
            BLOCK_CO=32,
        )
        # Bias gradient is the reduction of out_grad over all non-channel dims.
        if bias is not None:
            bias_grad = out_grad.sum(dim=(0, 2, 3))
        else:
            bias_grad = None

        # Convert gradients back to the caller's dtype if needed.
        # (Tensor.to is a no-op returning self when the dtype already matches.)
        if use_python_fp32:
            input_back = input_back.to(output_dtype)
            weight_back = weight_back.to(output_dtype)
            if bias_grad is not None:
                bias_grad = bias_grad.to(output_dtype)
        elif use_mixed_precision and weight_back.dtype != weight.dtype:
            # Mixed precision path: weight_back was FP32, convert back.
            weight_back = weight_back.to(weight.dtype)

        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )

707 

708 

# todo test SymInt[2] of stride or padding
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Functional conv2d entry point mirroring ``torch.nn.functional.conv2d``.

    ``padding`` may be an int / pair, or one of the strings ``"same"`` and
    ``"valid"``. All other arguments are forwarded to ``Conv2d.apply``.

    NOTE(review): the ``"same"`` branch applies the symmetric padding
    ``max(padding_h, padding_w)`` and then trims leading rows/columns via
    slicing; for asymmetric kernels/dilations this presumably matches the
    PyTorch reference, but that equivalence should be verified. Also,
    ``assert stride == 1`` rejects the equivalent ``(1, 1)`` tuple — confirm
    whether tuple strides should be accepted here.
    """
    if isinstance(padding, str):
        if padding == "same":
            # 'same' output requires stride 1 (as in torch.nn.functional.conv2d).
            assert stride == 1, (
                f"Doesn't support any stride values other than 1 in padding = 'same' mode, "
                f"received stride value {stride}"
            )
            ih = input.shape[-2]
            iw = input.shape[-1]
            kernel_size_h = weight.shape[-2]
            kernel_size_w = weight.shape[-1]
            import math

            # Per-dimension padding that keeps output size == input size.
            padding_h = int(
                math.ceil(
                    (stride * (ih - 1) + 1 + dilation * (kernel_size_h - 1) - ih) / 2
                )
            )
            padding_w = int(
                math.ceil(
                    (stride * (iw - 1) + 1 + dilation * (kernel_size_w - 1) - iw) / 2
                )
            )
            # Output sizes under the per-dimension padding (before the
            # symmetric max() padding is actually applied below).
            oh = int(
                (ih + 2 * padding_h - dilation * (kernel_size_h - 1) - 1) / stride + 1
            )
            ow = int(
                (iw + 2 * padding_w - dilation * (kernel_size_w - 1) - 1) / stride + 1
            )
            # Over-pad symmetrically, then trim the surplus leading rows/cols.
            padding = max(padding_h, padding_w)
            return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)[
                ..., (oh - ih) :, (ow - iw) :
            ]
        elif padding == "valid":
            # 'valid' means no padding at all.
            return Conv2d.apply(input, weight, bias, stride, 0, dilation, groups)
        else:
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    else:
        return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)