Coverage for src/flag_gems/ops/conv3d.py: 57%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.ops.conv2d import conv2d_output_size 

10from flag_gems.utils import libentry 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def conv3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """
    Determines the output size of a 3D convolution operation.

    The per-dimension output-size formula is identical for 2D and 3D
    convolution, so this simply delegates to ``conv2d_output_size`` and is
    called once per spatial dimension (depth, height, width).

    Args:
        in_size: Input size along one spatial dimension.
        kernel_size: Kernel size along the same dimension.
        stride: Stride along the same dimension.
        padding: Padding along the same dimension.
        dilation: Dilation along the same dimension.

    Returns:
        Output size of the 3D convolution along that dimension.
    """
    return conv2d_output_size(in_size, kernel_size, stride, padding, dilation)

36 

37 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv3d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_depth",
        "input_height",
        "input_width",
        "out_c",
        "out_depth",
        "out_height",
        "out_width",
        "weight_depth",
        "weight_height",
        "weight_width",
        "stride_depth",
        "stride_height",
        "stride_width",
        "padding_depth",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv3d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_depth,
    input_height,
    input_width,
    out_c,
    out_depth,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_depth_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_depth_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_depth_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_depth: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_depth: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_depth: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_depth: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_DO_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Direct (im2col-free) grouped 3D convolution forward kernel.

    Launch grid (see the `grid` lambda in `conv3d`):
      axis 0 tiles the flattened (batch, out_depth, out_height, out_width)
             index space in chunks of BLOCK_NI_DO_HO_WO;
      axis 1 tiles the per-group output channels in chunks of BLOCK_CO;
      axis 2 is one program per group.

    Each program accumulates one (BLOCK_NI_DO_HO_WO, BLOCK_CO) output tile in
    fp32 via `tl.dot` over the kernel's spatial taps and input-channel chunks,
    then adds the bias and stores with boundary masking.
    """
    pid_ni_do_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flat per-program offsets into (batch, out_d, out_h, out_w)
    # coordinates inside the kernel (out_width varies fastest).
    ni_do_ho_wo_offset = pid_ni_do_ho_wo * BLOCK_NI_DO_HO_WO + tl.arange(
        0, BLOCK_NI_DO_HO_WO
    )
    ni_do_ho_offset = ni_do_ho_wo_offset // out_width
    ni_do_offset = ni_do_ho_offset // out_height
    in_n_point_value = ni_do_offset // out_depth
    output_depth_point_value = ni_do_offset % out_depth
    output_height_point_value = ni_do_ho_offset % out_height
    output_width_point_value = ni_do_ho_wo_offset % out_width

    # Advance the base pointers to this program's batch rows and group slice.
    # Logically: input is (in_n, groups * weight_c, D, H, W) and weight is
    # (groups * out_per_group_c, weight_c, kD, kH, kW); the group offset moves
    # along the channel axis of each.
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    # fp32 accumulator for numerical stability regardless of the I/O dtype.
    accum = tl.zeros((BLOCK_NI_DO_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # One fused loop over (kernel_d, kernel_h, kernel_w, input-channel chunk);
    # the channel chunk index varies fastest, then w, then h, then d.
    for dhwc in range(weight_depth * weight_height * weight_width * BLOCK_CI_COUNT):
        c = (dhwc % BLOCK_CI_COUNT) * BLOCK_CI
        dhw = dhwc // BLOCK_CI_COUNT
        dh = dhw // weight_width
        d = dh // weight_height
        h = dh % weight_height
        w = dhw % weight_width

        # Input coordinates for this tap; may be negative or past the edge
        # (padding region) — masked out below.
        input_c_offset = c + tl.arange(0, BLOCK_CI)
        input_depth_offset = (
            d * dilation_depth - padding_depth + stride_depth * output_depth_point_value
        )
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        # (BLOCK_NI_DO_HO_WO, BLOCK_CI) input tile pointers.
        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_depth_stride * input_depth_offset)[:, None]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        # (BLOCK_CI, BLOCK_CO) weight tile pointers for the current tap.
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_depth_stride * d)
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        # Mask out-of-batch rows, channel-tail lanes, and padding coordinates.
        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_depth_offset)[:, None]
            & (input_depth_offset < input_depth)[:, None]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        # NOTE(review): no `other=` is given, so masked lanes rely on the
        # Triton default fill for the accumulation to stay correct — confirm
        # the targeted Triton version zero-fills masked loads.
        input_block = tl.load(curr_input_pointer, mask=input_mask)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask)

        # allow_tf32=False keeps fp32 dot products exact.
        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    # Bias is added unconditionally; the host passes a zeros tensor when the
    # caller supplied no bias.
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    # Scatter the tile back to (n, c, d, h, w) output coordinates.
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_depth_stride * output_depth_point_value)[:, None]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_depth_point_value < out_depth)[:, None]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)

215 

216 

217# class Conv3d(torch.autograd.Function): 

218# @staticmethod 

219# def forward(ctx, input, weight, bias, stride, padding, dilation, groups): 

220# pass 

221 

222 

def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a grouped 3D convolution with the Triton forward kernel.

    Args:
        input: Tensor of shape (batch, in_channels, depth, height, width).
        weight: Tensor of shape
            (out_channels, in_channels // groups, kD, kH, kW).
        bias: Optional 1D tensor of shape (out_channels,).
        stride: Single int or (depth, height, width) sequence.
        padding: Single int, (depth, height, width) sequence, or one of the
            strings "same" / "valid". "same" requires all strides to be 1.
        dilation: Single int or (depth, height, width) sequence.
        groups: Number of convolution groups.

    Returns:
        Tensor of shape (batch, out_channels, out_d, out_h, out_w); with
        padding="same" the spatial extent equals the input's.

    Raises:
        ValueError: If `padding` is a string other than "same"/"valid".
        AssertionError: On incompatible input/weight/bias shapes.
    """
    logger.debug("GEMS CONV3D")
    # Fixed: these messages previously lacked the `f` prefix, so the
    # placeholders were printed literally instead of being interpolated.
    assert weight.ndim == 5, f"Weights must be 5D, received shape {weight.shape}"
    assert (
        bias is None or bias.ndim == 1
    ), f"Bias must be 1D, received shape {bias.shape}"

    assert (
        input.shape[1] == groups * weight.shape[1]
    ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
    assert (
        bias is None or weight.shape[0] == bias.shape[0]
    ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

    # Normalize stride/dilation to per-dimension (depth, height, width) values.
    if isinstance(stride, (list, tuple)):
        stride_depth, stride_height, stride_width = stride
    else:
        stride_depth = stride_height = stride_width = stride

    if isinstance(dilation, (list, tuple)):
        dilation_depth, dilation_height, dilation_width = dilation
    else:
        dilation_depth = dilation_height = dilation_width = dilation

    if isinstance(padding, str):
        if padding == "same":
            assert (
                stride_depth == 1 and stride_height == 1 and stride_width == 1
            ), (
                "Doesn't support any stride values other than 1 in padding = "
                f"'same' mode, received stride value {stride}"
            )
            # Named in_d/in_h/in_w (not `id`) to avoid shadowing the builtin.
            in_d = input.shape[-3]
            in_h = input.shape[-2]
            in_w = input.shape[-1]
            kernel_size_d = weight.shape[-3]
            kernel_size_h = weight.shape[-2]
            kernel_size_w = weight.shape[-1]
            # Symmetric padding that keeps the output at least as large as the
            # input; ceil rounds up when the required total padding is odd,
            # and the surplus is trimmed after the kernel launch.
            padding_depth = math.ceil(
                (
                    stride_depth * (in_d - 1)
                    + 1
                    + dilation_depth * (kernel_size_d - 1)
                    - in_d
                )
                / 2
            )
            padding_height = math.ceil(
                (
                    stride_height * (in_h - 1)
                    + 1
                    + dilation_height * (kernel_size_h - 1)
                    - in_h
                )
                / 2
            )
            padding_width = math.ceil(
                (
                    stride_width * (in_w - 1)
                    + 1
                    + dilation_width * (kernel_size_w - 1)
                    - in_w
                )
                / 2
            )
            # Padded output extents (each either equal to or one larger than
            # the input extent); used below to slice off the excess.
            od = int(
                (in_d + 2 * padding_depth - dilation_depth * (kernel_size_d - 1) - 1)
                / stride_depth
                + 1
            )
            oh = int(
                (in_h + 2 * padding_height - dilation_height * (kernel_size_h - 1) - 1)
                / stride_height
                + 1
            )
            ow = int(
                (in_w + 2 * padding_width - dilation_width * (kernel_size_w - 1) - 1)
                / stride_width
                + 1
            )
        elif padding == "valid":
            padding_depth = padding_height = padding_width = 0
        else:
            # Fixed message: previously read "only'valild'/'same'".
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    elif isinstance(padding, (list, tuple)):
        padding_depth, padding_height, padding_width = padding
    else:
        padding_depth = padding_height = padding_width = padding

    in_n, _, input_depth, input_height, input_width = input.shape
    out_c, weight_c, weight_depth, weight_height, weight_width = weight.shape
    out_depth = conv3d_output_size(
        input_depth, weight_depth, stride_depth, padding_depth, dilation_depth
    )

    out_height = conv3d_output_size(
        input_height, weight_height, stride_height, padding_height, dilation_height
    )
    out_width = conv3d_output_size(
        input_width, weight_width, stride_width, padding_width, dilation_width
    )

    output_dtype = input.dtype
    output = torch.empty(
        (in_n, out_c, out_depth, out_height, out_width),
        device=input.device,
        dtype=output_dtype,
    )

    # Grid: axis 0 tiles the flattened (batch, out_d, out_h, out_w) space,
    # axis 1 tiles the per-group output channels, axis 2 is one program per
    # group.
    grid = lambda META: (
        triton.cdiv(
            in_n * out_depth * out_height * out_width, META["BLOCK_NI_DO_HO_WO"]
        ),
        triton.cdiv(out_c // groups, META["BLOCK_CO"]),
        groups,
    )

    # The kernel adds bias unconditionally, so substitute zeros when absent.
    if bias is None:
        bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
    else:
        bias_pointer = bias

    conv3d_forward_kernel[grid](
        input,
        weight,
        output,
        bias_pointer,
        in_n,
        input_depth,
        input_height,
        input_width,
        out_c,
        out_depth,
        out_height,
        out_width,
        *input.stride(),
        *weight.stride(),
        *output.stride(),
        weight_c,
        weight_depth,
        weight_height,
        weight_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        dilation_depth,
        dilation_height,
        dilation_width,
        groups=groups,
    )

    # "same" mode may have over-padded by one on an axis; trim the leading
    # excess so each spatial extent matches the input's.
    if padding == "same":
        output = output[..., (od - in_d) :, (oh - in_h) :, (ow - in_w) :]

    return output