Coverage for src/flag_gems/ops/avg_pool2d.py: 43%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

9logger = logging.getLogger(__name__) 

10 

11 

def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Compute one spatial dimension of a 2D pooling output.

    Mirrors PyTorch's pooling shape formula: the effective kernel extent is
    ``(kernel_size - 1) * dilation + 1``. With ``ceil_mode`` the division
    rounds up, but the extra window is dropped again if it would start
    entirely inside the right/bottom padding.
    """
    span = in_size + 2 * padding - ((kernel_size - 1) * dilation + 1)
    if not ceil_mode:
        return span // stride + 1

    size = -(-span // stride) + 1  # ceiling division for positive stride
    # A last window starting beyond the padded input contributes nothing; drop it.
    if (size - 1) * stride >= in_size + padding:
        size -= 1
    return size

30 

31 

# avg_pool2d forward: each program averages the pooling window for a
# BLOCK_H x BLOCK_W tile of output pixels of one (n, c) plane.
@libentry()
@triton.autotune(
    configs=[
        # Tile shapes trade occupancy against per-program work; autotune keys
        # on output geometry and pooling window so each shape picks its own tile.
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    input_ptr,
    output_ptr,  # written densely as (n*c, out_h, out_w); assumes contiguous output
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,  # True: divide by kernel_h*kernel_w; False: by in-bounds count
    divisor_override,  # non-zero: fixed divisor; 0 encodes "not set" (see wrapper)
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Grid axis 0 = flattened (n, c); axis 1 = flattened (h-tile, w-tile).
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Accumulate in fp32 regardless of input dtype; count_acc tracks how many
    # taps of each window fell inside the input (used when not counting pad).
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    # Walk the pooling window; out-of-bounds taps (padding) load 0 and are
    # excluded from count_acc.
    for kh in range(0, kernel_h):
        for kw in range(0, kernel_w):
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)

            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=0.0
            )

            sum_acc += tl.where(in_mask, current_val, 0.0)
            count_acc += in_mask.to(tl.int32)

    # Divisor precedence mirrors torch.nn.functional.avg_pool2d:
    # divisor_override > count_include_pad > per-pixel in-bounds count.
    if divisor_override != 0:
        divisor = tl.full((BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32)
    elif COUNT_INCLUDE_PAD:
        divisor = tl.full((BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32)
    else:
        divisor = count_acc.to(tl.float32)

    # Guard against a fully-padded window (count 0) producing NaN/inf.
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)

    # Output addressed contiguously: plane pid_nc starts at pid_nc*out_h*out_w.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )

128 

129 

# avg_pool2d backward, gather formulation: each program owns a tile of
# grad_input and sums grad_output / divisor over every output window that
# covers each input pixel — no atomics required.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides ("in" = grad_input layout, "out" = grad_output layout)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,  # divisor precedence matches the forward kernel
    divisor_override,  # non-zero: fixed divisor; 0 encodes "not set"
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Grid axis 0 = flattened (n, c); axis 1 = tiles of the *input* plane.
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    num_w_blocks = tl.cdiv(in_w, BLOCK_W)

    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c

    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    # For each kernel tap (kh, kw), invert the forward index map
    #   h_in = h_out * stride_h - padding_h + kh * dilation_h
    # i.e. h_out * stride_h = h_in + padding_h - kh * dilation_h. The tap maps
    # to a real output only when that numerator is non-negative and divisible
    # by the stride (h_out >= 0 is implied by h_out_num >= 0).
    for kh_loop in range(kernel_h):
        for kw_loop in range(kernel_w):
            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w

            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)

            h_out = h_out_num // stride_h
            w_out = w_out_num // stride_w

            h_out_mask = h_valid_map & (h_out < out_h)
            w_out_mask = w_valid_map & (w_out < out_w)
            out_mask = h_out_mask & w_out_mask

            # Divisor used by that output window, mirroring the forward pass.
            if divisor_override != 0:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                )
            elif COUNT_INCLUDE_PAD:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                )
            else:
                # Recount the in-bounds taps of the window anchored at
                # (h_out, w_out) — must agree with the forward count_acc.
                h_start = h_out * stride_h - padding_h
                w_start = w_out * stride_w - padding_w
                count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
                for kh_count in range(0, kernel_h):
                    for kw_count in range(0, kernel_w):
                        h_in_for_count = h_start + kh_count * dilation_h
                        w_in_for_count = w_start + kw_count * dilation_w
                        is_valid = (
                            (h_in_for_count >= 0)
                            & (h_in_for_count < in_h)
                            & (w_in_for_count >= 0)
                            & (w_in_for_count < in_w)
                        )
                        count += is_valid.to(tl.int32)
                divisor = count.to(tl.float32)

            # Avoid 0/0 on lanes that out_mask discards anyway.
            divisor = tl.where(divisor == 0, 1.0, divisor)

            grad_out_ptr = (
                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
            )
            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
            grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)

    grad_input_store_ptr = (
        grad_input_block_ptr
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )

256 

257 

258def _parse_pool_params(kernel_size, stride, padding): 

259 if isinstance(kernel_size, int): 

260 kernel_h = kernel_w = kernel_size 

261 else: 

262 kernel_h, kernel_w = kernel_size 

263 

264 if stride is None or (isinstance(stride, (list, tuple)) and not stride): 

265 stride_h, stride_w = kernel_h, kernel_w 

266 elif isinstance(stride, int): 

267 stride_h = stride_w = stride 

268 else: 

269 stride_h, stride_w = stride 

270 

271 if isinstance(padding, int): 

272 padding_h = padding_w = padding 

273 else: 

274 padding_h, padding_w = padding 

275 

276 if stride_h <= 0 or stride_w <= 0: 

277 raise ValueError("stride must be greater than zero") 

278 

279 if padding_h < 0 or padding_w < 0: 

280 raise ValueError("padding must be non-negative") 

281 

282 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2: 

283 raise ValueError("pad should be smaller than or equal to half of kernel size") 

284 

285 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w 

286 

287 

def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """2D average pooling over a 4D NCHW tensor via a Triton kernel.

    Follows torch.nn.functional.avg_pool2d semantics: ``stride`` defaults to
    ``kernel_size``; ``divisor_override``, when given, replaces the divisor
    for every window. Returns a tensor of shape (N, C, out_h, out_w).
    """
    logger.debug("GEMS AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
    ) = _parse_pool_params(kernel_size, stride, padding)
    dilation_h = dilation_w = 1  # aten avg_pool2d has no dilation

    in_n, in_c, in_h, in_w = input.shape

    out_h = pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    if output.numel() == 0:
        # Nothing to launch for degenerate shapes.
        return output

    def grid(meta):
        # One program per (n, c) plane per output tile.
        return (
            in_n * in_c,
            triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
        )

    avg_pool2d_forward_kernel[grid](
        input,
        output,
        *input.stride(),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel uses 0 as the "no override" sentinel.
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return output

355 

356 

def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of avg_pool2d with respect to ``input``.

    Pooling hyper-parameters must match the forward call. Returns a tensor
    with the shape/dtype of ``grad_output``'s dtype and ``input``'s shape.
    """
    logger.debug("GEMS AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    grad_output = grad_output.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
    ) = _parse_pool_params(kernel_size, stride, padding)
    dilation_h = dilation_w = 1  # aten avg_pool2d has no dilation

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 for stability; cast back to the caller's dtype at the end.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_output.numel() == 0:
        return grad_input.to(grad_output.dtype)

    def grid(meta):
        # The backward kernel tiles the *input* plane (gather formulation).
        return (
            in_n * in_c,
            triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(in_w, meta["BLOCK_W"]),
        )

    avg_pool2d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        *grad_input.stride(),
        *grad_output.stride(),
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel uses 0 as the "no override" sentinel.
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return grad_input.to(grad_output.dtype)