Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/avg_pool2d.py: 0%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

9logger = logging.getLogger(__name__) 

10 

11 

def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Compute one spatial dimension of a pooled output.

    Follows PyTorch's pooling shape rule: the effective kernel extent is
    ``(kernel_size - 1) * dilation + 1``; with ``ceil_mode`` a trailing
    window is dropped if it would start entirely inside the right padding.
    """
    span = in_size + 2 * padding - ((kernel_size - 1) * dilation + 1)
    if not ceil_mode:
        return span // stride + 1
    # Ceil division (equivalent to (span + stride - 1) // stride for stride > 0).
    out = -(-span // stride) + 1
    # A window may not begin in the padded region beyond the input's edge.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out

30 

31 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Forward 2D average pooling over one (batch, channel) plane.

    Grid axis 0 enumerates (n, c) pairs; axis 1 enumerates BLOCK_H x BLOCK_W
    tiles of the output plane.  The input is addressed through its strides,
    but the store below indexes the output as ``pid_nc * out_h * out_w``, so
    the output tensor must be contiguous NCHW.
    """
    # Decompose the two program ids into (n, c) and an output tile position.
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    # Output coordinates covered by this tile.
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Window sum and number of in-bounds taps for every output element.
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    # Fully unrolled walk over the pooling window (kernel dims are constexpr).
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Input coordinates of this tap for every output element; padding
            # shifts the window's origin, so coordinates can be negative.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)

            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=0.0
            )

            # `other=0.0` already zeroes masked lanes; the tl.where is a
            # belt-and-braces guard that keeps the sum purely in-bounds.
            sum_acc += tl.where(in_mask, current_val, 0.0)
            count_acc += in_mask.to(tl.int32)

    count_divisor = count_acc.to(tl.float32)

    if COUNT_INCLUDE_PAD:
        # Divide by the full window area.  The always-true `>= 0` predicate is
        # only a trick to broadcast the scalar area to the tile's shape.
        # NOTE(review): with ceil_mode the last window can extend past the
        # padded input; PyTorch clips the divisor there — confirm parity if
        # ceil_mode + count_include_pad is exercised.
        default_divisor = tl.where(
            count_divisor >= 0, float(kernel_h * kernel_w), count_divisor
        )
    else:
        # Divide only by the taps that actually fell inside the input.
        default_divisor = count_divisor

    # A non-zero divisor_override wins; `+ default_divisor * 0` broadcasts the
    # scalar override to the tile's shape.  (The host passes 0.0 for "None".)
    divisor = tl.where(
        divisor_override != 0, divisor_override + default_divisor * 0, default_divisor
    )

    # Guard against division by zero for windows with no valid taps.
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)

    # Store the tile; output is addressed as a dense, contiguous NCHW tensor.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )

126 

127 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Backward 2D average pooling, gather-style over INPUT positions.

    Each program owns a BLOCK_H x BLOCK_W tile of grad_input and, for every
    input element, sums grad_output / divisor over all output windows that
    contained it.  Gathering per input position avoids atomics that a
    scatter formulation would need.
    """
    # Decompose program ids into (n, c) and an input-plane tile position.
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    num_w_blocks = tl.cdiv(in_w, BLOCK_W)

    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c

    # Input coordinates covered by this tile.
    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    # For each window tap (kh, kw), find the output position (if any) whose
    # window places that tap on this input element:
    #   h_in = h_out * stride_h - padding_h + kh * dilation_h
    # solved for h_out; it must be a non-negative exact multiple of stride.
    for kh_loop in tl.static_range(0, kernel_h):
        for kw_loop in tl.static_range(0, kernel_w):
            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w

            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)

            h_out = h_out_num // stride_h
            w_out = w_out_num // stride_w

            h_out_mask = h_valid_map & (h_out < out_h)
            w_out_mask = w_valid_map & (w_out < out_w)
            out_mask = h_out_mask & w_out_mask

            # Compute count for this output position (for count_include_pad=False)
            # by re-walking that output's whole window and counting in-bounds
            # taps.  NOTE(review): this nested recount makes the kernel
            # O(kernel^4) per input element — correct but a candidate for
            # precomputation if large kernels matter.
            h_start = h_out * stride_h - padding_h
            w_start = w_out * stride_w - padding_w
            count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
            for kh_count in tl.static_range(0, kernel_h):
                for kw_count in tl.static_range(0, kernel_w):
                    h_in_for_count = h_start + kh_count * dilation_h
                    w_in_for_count = w_start + kw_count * dilation_w
                    is_valid = (
                        (h_in_for_count >= 0)
                        & (h_in_for_count < in_h)
                        & (w_in_for_count >= 0)
                        & (w_in_for_count < in_w)
                    )
                    count += is_valid.to(tl.int32)

            count_divisor = count.to(tl.float32)

            if COUNT_INCLUDE_PAD:
                # Full window area; the always-true `>= 0` predicate only
                # broadcasts the scalar to the tile's shape.
                default_divisor = tl.where(
                    count_divisor >= 0, float(kernel_h * kernel_w), count_divisor
                )
            else:
                default_divisor = count_divisor

            # Non-zero divisor_override wins (host passes 0.0 for "None");
            # `+ default_divisor * 0` broadcasts the scalar.
            divisor = tl.where(
                divisor_override != 0,
                divisor_override + default_divisor * 0,
                default_divisor,
            )
            # Avoid dividing by zero; such lanes are masked out below anyway.
            divisor = tl.where(divisor == 0, 1.0, divisor)

            grad_out_ptr = (
                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
            )
            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
            grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)

    # One masked store per input element; no atomics needed since each
    # program exclusively owns its grad_input tile.
    grad_input_store_ptr = (
        grad_input_block_ptr
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )

252 

253 

254def _parse_pool_params(kernel_size, stride, padding): 

255 if isinstance(kernel_size, int): 

256 kernel_h = kernel_w = kernel_size 

257 else: 

258 kernel_h, kernel_w = kernel_size 

259 

260 if stride is None or (isinstance(stride, (list, tuple)) and not stride): 

261 stride_h, stride_w = kernel_h, kernel_w 

262 elif isinstance(stride, int): 

263 stride_h = stride_w = stride 

264 else: 

265 stride_h, stride_w = stride 

266 

267 if isinstance(padding, int): 

268 padding_h = padding_w = padding 

269 else: 

270 padding_h, padding_w = padding 

271 

272 if stride_h <= 0 or stride_w <= 0: 

273 raise ValueError("stride must be greater than zero") 

274 

275 if padding_h < 0 or padding_w < 0: 

276 raise ValueError("padding must be non-negative") 

277 

278 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2: 

279 raise ValueError("pad should be smaller than or equal to half of kernel size") 

280 

281 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w 

282 

283 

def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Average pooling over a 4D NCHW tensor using the Triton forward kernel.

    Mirrors ``torch.nn.functional.avg_pool2d``'s signature; returns a tensor
    of shape (N, C, out_h, out_w) with the input's device and dtype.
    """
    logger.debug("GEMS AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    # The kernel's output store assumes dense layout; make input contiguous too.
    inp = input.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding)
    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = params
    dilation_h = dilation_w = 1

    batch, channels, height, width = inp.shape

    out_h = pool2d_output_size(
        height, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        width, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (batch, channels, out_h, out_w), device=inp.device, dtype=inp.dtype
    )

    # Nothing to compute for empty outputs; skip the kernel launch entirely.
    if output.numel() == 0:
        return output

    # One program per (n, c) plane per output tile.
    def grid(meta):
        tiles = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return (batch * channels, tiles)

    avg_pool2d_forward_kernel[grid](
        inp,
        output,
        inp.stride(0),
        inp.stride(1),
        inp.stride(2),
        inp.stride(3),
        channels,
        height,
        width,
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0.0 as "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return output

351 

352 

def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of avg_pool2d with respect to its input.

    Launches the gather-style Triton backward kernel; returns grad_input with
    the same shape as ``input`` and the dtype of ``grad_output``.
    """
    logger.debug("GEMS AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    grad_out = grad_output.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding)
    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = params
    dilation_h = dilation_w = 1

    batch, channels, height, width = input.shape
    out_h, out_w = grad_out.shape[2], grad_out.shape[3]

    # Accumulate in float32 for precision; cast back to the caller's dtype
    # on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    # Empty upstream gradient => all-zero grad_input; no launch needed.
    if grad_out.numel() == 0:
        return grad_input.to(grad_out.dtype)

    # One program per (n, c) plane per grad_input tile.
    def grid(meta):
        tiles = triton.cdiv(height, meta["BLOCK_H"]) * triton.cdiv(
            width, meta["BLOCK_W"]
        )
        return (batch * channels, tiles)

    avg_pool2d_backward_kernel[grid](
        grad_out,
        grad_input,
        channels,
        height,
        width,
        out_h,
        out_w,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_out.stride(0),
        grad_out.stride(1),
        grad_out.stride(2),
        grad_out.stride(3),
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0.0 as "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return grad_input.to(grad_out.dtype)