Coverage for src/flag_gems/runtime/backend/_mthreads/ops/conv2d.py: 0%

161 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry 

9 

# Logger scoped to this op module inside the mthreads backend namespace.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.%s" % __name__.rsplit(".", 1)[-1]
)

13 

14 

def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Compute one spatial output dimension of a 2D convolution.

    Uses the standard conv arithmetic: the padded extent minus the
    effective (dilated) kernel extent, stepped by the stride.

    Args:
        in_size: Input size along this dimension.
        kernel_size: Kernel size along this dimension.
        stride: Stride along this dimension.
        padding: Zero-padding applied to both sides.
        dilation: Spacing between kernel taps.

    Returns:
        Output size along this dimension.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded_extent = in_size + 2 * padding
    return (padded_extent - effective_kernel) // stride + 1

36 

37 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_height",
        "input_width",
        "out_c",
        "out_height",
        "out_width",
        "weight_height",
        "weight_width",
        "stride_height",
        "stride_width",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Tiled direct 2D convolution forward kernel (grouped, with bias).

    Grid layout: axis 0 tiles the flattened (batch, out_h, out_w) space in
    chunks of BLOCK_NI_HO_WO; axis 1 tiles this group's output channels in
    chunks of BLOCK_CO; axis 2 selects the group. Each program accumulates
    one (BLOCK_NI_HO_WO, BLOCK_CO) output tile in fp32 by reducing over
    (in-channel, kh, kw) with tl.dot, adds the bias, and stores the tile
    with bounds masks.
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flat axis-0 offsets into (batch, out_row, out_col).
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Advance base pointers to this program's batch rows and group slice.
    # input and weight are logically shaped
    # [in_n, groups, in_c, input_height, input_width] and
    # [groups, out_c, in_c, weight_height, weight_width].
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # fp32 accumulator for the output tile.
    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # Single fused loop over (kh, kw, channel-chunk) of the reduction.
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Map output coordinates back to (possibly padded) input coordinates
        # for kernel tap (h, w).
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask out-of-range batch rows, channel tails, and padded positions
        # (masked loads yield 0, which is neutral for the dot product).
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        # (BLOCK_NI_HO_WO, BLOCK_CI) @ (BLOCK_CI, BLOCK_CO) partial product.
        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # Add this group's bias slice, broadcast over the row (pixel) dimension.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    # Scatter the tile back to [n, c, ho, wo] with tail masks on every axis.
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)

186 

187 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_backward_weight"),
    key=[
        "in_n",
        "input_height",
        "input_width",
        "weight_height",
        "weight_width",
        "input_c",
        "stride_height",
        "stride_width",
        "out_height",
        "out_width",
        "out_c",
        "padding_height",
        "padding_width",
    ],
)
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Weight-gradient kernel for grouped 2D convolution.

    Each program computes a (BLOCK_CI_HK_WK, BLOCK_CO) tile of the weight
    gradient, reducing input * out_grad over batch and all output positions.
    The batch dimension is blocked by BLOCK_NO and reduced via tl.dot; the
    spatial (out_height, out_width) reduction is scalar loops.

    Tensor layouts (logical):
      out_grad: [n, groups * out_c, ho, wo]
      weight:   [groups * out_c, ci, kh, kw]  (gradient is written here)
      input:    [n, groups * ci, hi, wi]
    Note: out_c here is the per-group output channel count.
    """
    # Grid axes: 0 tiles the flattened ci*kh*kw space, 1 selects the group,
    # 2 tiles the (per-group) output channels.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flat axis-0 offsets into (ci, kh, kw) indices.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Advance base pointers to this program's group and channel tiles.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups[None] * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    input_pointer += (ci_point_value * input_c_stride[None])[:, None] + (
        pid_groups[None] * input_c_stride * input_c
    )[None, :]

    # Reduce over every output position (h, w) and over batch blocks; the
    # input tile is [ci*kh*kw, BLOCK_NO], the out_grad tile [BLOCK_NO, co],
    # so N (per block) is the tl.dot reduction dimension.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # Gather the out_grad tile at output position (h, w).
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                curr_out_grad = tl.load(curr_out_grad_pointer, mask=out_grad_mask)

                # Input coordinates that feed output (h, w) through kernel
                # tap (kh, kw), accounting for stride, padding and dilation.
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                # Mask batch tails, channel tails, and padded positions
                # (masked loads yield 0, neutral for the dot product).
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    # Store the [ci*kh*kw, out_c] gradient tile with tail masks on all axes.
    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)

338 

339 

class Conv2d(torch.autograd.Function):
    """Autograd-enabled 2D convolution backed by Triton kernels.

    ``forward`` launches ``conv2d_forward_kernel`` directly. ``backward``
    computes the input gradient as a stride-1 convolution of the
    zero-dilated output gradient with the spatially-flipped weights
    (transposed-convolution trick), and the weight gradient with
    ``conv2d_backward_kernel_weight``.
    """

    @staticmethod
    def _pair(value):
        """Normalize an int or (h, w) list/tuple to an (h, w) tuple."""
        if isinstance(value, (list, tuple)):
            return tuple(value)
        return (value, value)

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Run the forward convolution and stash what backward needs.

        Args:
            input: Input tensor of shape (N, groups * C_in, H, W).
            weight: Weight tensor of shape (C_out, C_in, kH, kW).
            bias: Optional 1D bias of shape (C_out,).
            stride/padding/dilation: int or (h, w) pair each.
            groups: Number of convolution groups.

        Returns:
            Output tensor of shape (N, C_out, out_H, out_W).
        """
        logger.debug("GEMS_MTHREADS CONV2D")
        # NOTE: these messages previously lacked the f-prefix and never
        # interpolated; fixed to real f-strings.
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

        stride_height, stride_width = Conv2d._pair(stride)
        padding_height, padding_width = Conv2d._pair(padding)
        dilation_height, dilation_width = Conv2d._pair(dilation)

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype
        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=output_dtype,
        )

        # Grid: axis 0 tiles flattened (n, ho, wo), axis 1 tiles the
        # per-group output channels, axis 2 is one program per group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(out_c // groups, META["BLOCK_CO"]),
            groups,
        )

        if bias is None:
            # The kernel adds bias unconditionally, so substitute zeros.
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
        else:
            bias_pointer = bias
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # weight_info stores the PER-GROUP output channel count.
        ctx.weight_info = (out_c // groups, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Compute gradients for input, weight, and (optionally) bias.

        Returns one gradient per ``forward`` argument; ``None`` for the
        non-tensor arguments (stride, padding, dilation, groups).
        """
        logger.debug("GEMS_MTHREADS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c here is the per-group output channel count (see forward).
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # Input gradient = "full" correlation of out_grad with the
        # spatially flipped weights; this padding realizes it.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        # torch.flip already returns a copy, so no clone is needed.
        revert_weight = torch.flip(weight, dims=[2, 3]).contiguous()

        if groups != 1:
            # Swap in/out channel axes within each group:
            # (G*out_c, ci, kh, kw) -> (G*ci, out_c, kh, kw)
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Dilate out_grad with (stride - 1) zeros between elements so a
        # stride-1 forward convolution realizes the transposed convolution.
        new_out_height = out_grad.shape[2] + (stride_height - 1) * (
            out_grad.shape[2] - 1
        )
        new_out_width = out_grad.shape[3] + (stride_width - 1) * (out_grad.shape[3] - 1)

        if stride_height > 1 or stride_width > 1:
            new_out = torch.zeros(
                out_grad.shape[0],
                out_grad.shape[1],
                new_out_height,
                new_out_width,
                device=device,
                dtype=out_grad.dtype,
            )
            # Single strided scatter instead of a Python double loop of
            # per-element copies (one kernel launch vs O(ho*wo)).
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            # No dilation needed; also skip the unused zeros allocation.
            new_out = out_grad

        # NOTE(review): input gradient is accumulated/returned in fp32
        # regardless of input dtype — confirm this matches caller intent.
        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=torch.float32,
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(weight_c, META["BLOCK_CO"]),
            groups,
        )
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        # Reuse the forward kernel: convolve new_out with flipped weights at
        # stride 1 and the "revert" padding to produce d(input).
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight.dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            *weight.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
        )
        if bias is not None:
            # Bias gradient: sum over batch and spatial dims (fp64 accum).
            bias_grad = out_grad.to(torch.float64).sum(dim=(0, 2, 3))
        else:
            bias_grad = None
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )

592 

593 

# TODO: add tests covering SymInt[2] inputs for stride and padding

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Functional entry point for the Triton-backed 2D convolution.

    Thin wrapper that forwards every argument to the ``Conv2d`` autograd
    function, mirroring ``torch.nn.functional.conv2d``'s signature.
    """
    conv_args = (input, weight, bias, stride, padding, dilation, groups)
    return Conv2d.apply(*conv_args)