Coverage for src/flag_gems/ops/conv2d.py: 42%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Return the spatial output extent of a 2D convolution along one axis.

    Uses the standard conv arithmetic: the padded input length minus the
    effective (dilated) kernel span, floor-divided by the stride, plus one.

    Args:
        in_size: Input extent along this axis.
        kernel_size: Kernel extent along this axis.
        stride: Stride along this axis.
        padding: Zero-padding added on each side.
        dilation: Spacing between kernel taps.

    Returns:
        Output extent along this axis.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded_size = in_size + 2 * padding
    return (padded_size - effective_kernel) // stride + 1

35 

36 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_height",
        "input_width",
        "out_c",
        "out_height",
        "out_width",
        "weight_height",
        "weight_width",
        "stride_height",
        "stride_width",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    # Grouped conv2d forward as an implicit GEMM: each program computes a
    # (BLOCK_NI_HO_WO, BLOCK_CO) tile of the output for one group.
    # Grid axes: 0 tiles the flattened (n, out_h, out_w) dimension, 1 tiles
    # the per-group output channels, 2 is one program per group.
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flat (n * out_height * out_width) offsets into the batch
    # index and the output spatial coordinates handled by this program.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Advance the base pointers to this group's channel slice. input and
    # weight are logically viewed as [in_n, groups, in_c, input_height,
    # input_width] and [groups, out_c_per_group, in_c, weight_height,
    # weight_width]; weight_c here is the per-group input channel count.
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # Accumulate in float32 regardless of the input dtype for accuracy.
    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # Single fused reduction loop over (kernel_h, kernel_w, in_c blocks);
    # each iteration adds one GEMM tile into accum.
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Map each output coordinate back to the input coordinate read by
        # this kernel tap (may be negative / past the edge in the padding).
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask out-of-range batch rows, channel tails and padded border taps.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # Bias add after the reduction loop; callers pass a zero tensor when the
    # convolution has no bias, so the add is unconditional.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    # Scatter the tile into output[n, group * out_per_group_c + co, ho, wo].
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)

185 

186 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_backward_weight"),
    key=[
        "in_n",
        "input_height",
        "input_width",
        "weight_height",
        "weight_width",
        "input_c",
        "stride_height",
        "stride_width",
        "out_height",
        "out_width",
        "out_c",
        "padding_height",
        "padding_width",
    ],
)
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    # Weight-gradient kernel. Logical views (out_c / input_c are per-group):
    #   out_grad: [n, groups * out_c, ho, wo]
    #   weight (gradient being written): [groups * out_c, ci, kh, kw]
    #   input: [n, groups * input_c, hi, wi]
    # Each program reduces over (n, ho, wo) to produce a
    # (BLOCK_CI_HK_WK, BLOCK_CO) tile of dW for one group.

    # Grid axes: 0 tiles the flattened (ci, kh, kw) dim, 1 is one program
    # per group, 2 tiles the per-group output channels.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flat (ci * weight_height * weight_width) offsets into the
    # input-channel index and the kernel spatial coordinates.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Advance the base pointers to this group / output-channel tile.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups[None] * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    input_pointer += (ci_point_value * input_c_stride[None])[:, None] + (
        pid_groups[None] * input_c_stride * input_c
    )[None, :]

    # Reduce over every output position (h, w) and batch block; each inner
    # iteration is a GEMM of an input tile [ci*kh*kw, BLOCK_NO] with an
    # out_grad tile [BLOCK_NO, out_c].
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # out_grad tile at this output position for BLOCK_NO batches.
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)

                # Input coordinate touched by kernel tap (kh, kw) at output
                # position (h, w); may land in the zero padding.
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                # Mask batch tail, channel tail and padded border positions.
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    # Write the finished dW tile, masking tails on every axis.
    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)

337 

338 

class Conv2d(torch.autograd.Function):
    """Grouped 2D convolution backed by Triton kernels.

    ``forward`` launches ``conv2d_forward_kernel``. ``backward`` computes the
    input gradient by reusing the forward kernel as a stride-1 transposed
    convolution with a flipped weight, the weight gradient with
    ``conv2d_backward_kernel_weight``, and the bias gradient by summing
    ``out_grad`` over the batch and spatial dimensions.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Run the grouped conv2d forward pass.

        Args:
            input: 4D tensor of shape (n, c_in, h, w).
            weight: 4D tensor of shape (c_out, c_in // groups, kh, kw).
            bias: optional 1D tensor of shape (c_out,).
            stride, padding, dilation: int or (height, width) pair each.
            groups: number of convolution groups.

        Returns:
            Output tensor of shape (n, c_out, out_h, out_w).
        """
        logger.debug("GEMS CONV2D")
        # These messages must be f-strings, otherwise the {…} placeholders
        # are printed literally.
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        # Normalize scalar-or-pair arguments to explicit (height, width).
        if isinstance(stride, (list, tuple)):
            stride_height, stride_width = stride
        else:
            stride_height = stride_width = stride

        if isinstance(padding, (list, tuple)):
            padding_height, padding_width = padding
        else:
            padding_height = padding_width = padding

        if isinstance(dilation, (list, tuple)):
            dilation_height, dilation_width = dilation
        else:
            dilation_height = dilation_width = dilation

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype
        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=output_dtype,
        )

        # Grid: BLOCK_NI_HO_WO tiles the flattened (n, out_h, out_w) dim,
        # BLOCK_CO tiles the per-group output channels, one slice per group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(out_c // groups, META["BLOCK_CO"]),
            groups,
        )

        # The kernel adds bias unconditionally, so substitute zeros when the
        # convolution has no bias.
        if bias is None:
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
        else:
            bias_pointer = bias
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # Per-group output channel count; floor division avoids the float
        # round-trip of int(out_c / groups).
        ctx.weight_info = (out_c // groups, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Compute gradients w.r.t. input, weight and bias.

        Returns a 7-tuple matching forward's inputs; the stride/padding/
        dilation/groups slots are None.
        """
        logger.debug("GEMS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c here is the per-group output channel count saved in forward.
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # Input gradient = stride-1 convolution of the zero-dilated out_grad
        # with the spatially flipped, cin/cout-swapped weight, using
        # "reverted" padding.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        revert_weight = weight.clone()
        revert_weight = torch.flip(revert_weight, dims=[2, 3]).contiguous()

        if groups != 1:
            # Swap the per-group cin/cout axes: [g, co, ci, kh, kw] ->
            # [g * ci, co, kh, kw].
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Insert (stride - 1) zeros between out_grad elements so the
        # transposed convolution can run with stride 1.
        new_out_height = out_grad.shape[2] + (stride_height - 1) * (
            out_grad.shape[2] - 1
        )
        new_out_width = out_grad.shape[3] + (stride_width - 1) * (out_grad.shape[3] - 1)

        new_out = torch.zeros(
            out_grad.shape[0],
            out_grad.shape[1],
            new_out_height,
            new_out_width,
            device=device,
            dtype=out_grad.dtype,
        )

        # Scatter out_grad into the zero-dilated buffer with one strided
        # assignment instead of a Python loop over every spatial position.
        if stride_height > 1 or stride_width > 1:
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            new_out = out_grad

        # NOTE(review): the input gradient is materialized in float32
        # regardless of the input dtype — confirm this is intended.
        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=torch.float32,
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(weight_c, META["BLOCK_CO"]),
            groups,
        )
        # Forward kernel reused with stride 1 and a zero bias.
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight.dtype,
            device=device,
        )

        # Grid: axis 0 tiles (ci, kh, kw), axis 1 is the group, axis 2 tiles
        # the per-group output channels.
        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
        )
        # Bias gradient: reduce out_grad over batch and spatial dimensions.
        if bias is not None:
            bias_grad = out_grad.sum(dim=(0, 2, 3))
        else:
            bias_grad = None
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )

591 

592 

# todo test SymInt[2] of stride or padding
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Functional grouped 2D convolution entry point.

    Mirrors ``torch.nn.functional.conv2d``: ``padding`` may be an int, a
    (height, width) pair, or one of the strings ``"same"`` / ``"valid"``.

    Raises:
        ValueError: if ``padding`` is a string other than "same" or "valid".
    """
    if isinstance(padding, str):
        if padding == "same":
            # "same" requires unit stride, matching torch's restriction.
            # The message must be an f-string for {stride} to interpolate.
            assert stride == 1, (
                "Doesn't support any stride values other than 1 "
                f"in padding = 'same' mode, received stride value {stride}"
            )
            ih = input.shape[-2]
            iw = input.shape[-1]
            kernel_size_h = weight.shape[-2]
            kernel_size_w = weight.shape[-1]
            # Smallest symmetric padding keeping the output at least as large
            # as the input (rounded up when the required total padding is odd).
            padding_h = int(
                math.ceil(
                    (stride * (ih - 1) + 1 + dilation * (kernel_size_h - 1) - ih) / 2
                )
            )
            padding_w = int(
                math.ceil(
                    (stride * (iw - 1) + 1 + dilation * (kernel_size_w - 1) - iw) / 2
                )
            )
            oh = int(
                (ih + 2 * padding_h - dilation * (kernel_size_h - 1) - 1) / stride + 1
            )
            ow = int(
                (iw + 2 * padding_w - dilation * (kernel_size_w - 1) - 1) / stride + 1
            )
            # NOTE(review): a single max padding is applied on both axes and
            # the surplus rows/cols are sliced off afterwards; the slice uses
            # oh/ow computed from the per-axis paddings — confirm correctness
            # when kernel_size_h != kernel_size_w.
            padding = max(padding_h, padding_w)
            return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)[
                ..., (oh - ih) :, (ow - iw) :
            ]
        elif padding == "valid":
            return Conv2d.apply(input, weight, bias, stride, 0, dilation, groups)
        else:
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    else:
        return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)