Coverage for src/flag_gems/runtime/backend/_cambricon/ops/avg_pool2d.py: 0%

177 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

9from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the output length along one spatial axis of a 2D pooling op.

    Implements the standard PyTorch pooling size formula, including the
    ceil_mode correction that drops a final window starting entirely inside
    the trailing padding.
    """
    # Extent actually covered by the (possibly dilated) kernel.
    span = (kernel_size - 1) * dilation + 1
    # Room the window's start position can slide over.
    slack = in_size + 2 * padding - span

    if not ceil_mode:
        return slack // stride + 1

    # Ceiling division via negated floor division.
    size = -(-slack // stride) + 1
    # A window may not begin past the padded input's right/bottom edge.
    if (size - 1) * stride >= in_size + padding:
        size -= 1
    return size

32 

33 

def limit_grid(grid_0, grid_1):
    """Clamp a 2D launch grid to the backend's hardware limits.

    Axis 0 is capped at a quarter of MAX_GRID_SIZE_X; axis 1 at
    MAX_GRID_SIZE_Y. Kernels compensate for the clamp with their internal
    grid-stride loops.
    """
    cap_0 = MAX_GRID_SIZE_X // 4
    cap_1 = MAX_GRID_SIZE_Y
    return (grid_0 if grid_0 < cap_0 else cap_0, grid_1 if grid_1 < cap_1 else cap_1)

38 

39 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    # avg_pool2d forward: each program averages pooling windows for a
    # BLOCK_H x BLOCK_W tile of output positions of one (n, c) plane.
    input_ptr,
    output_ptr,
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0 (= N * C planes)
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "no override"; any nonzero value replaces the divisor
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Number of output tiles per (n, c) plane.
    task_num_1 = tl.cdiv(out_h, BLOCK_H) * tl.cdiv(out_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    # Grid-stride loops on both axes: the launch grid may be clamped below the
    # task counts (see limit_grid), so each program walks its axis in strides
    # of the grid size until all tasks are covered.
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            # Decompose the flat tile id into (tile row, tile col), and the
            # flat plane id into (batch index, channel index).
            num_w_blocks = tl.cdiv(out_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c

            h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

            # Accumulate window sums in fp32 and count how many in-bounds
            # elements each window actually touched (for count_include_pad=False).
            sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
            count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)

            input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

            # Walk the kernel footprint; out-of-bounds taps are masked out.
            for kh in range(0, kernel_h):
                for kw in range(0, kernel_w):
                    h_in = (
                        h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                    )
                    w_in = (
                        w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)

                    input_offset = h_in * in_stride_h + w_in * in_stride_w
                    current_val = tl.load(
                        input_base_ptr + input_offset, mask=in_mask, other=0.0
                    )

                    sum_acc += tl.where(in_mask, current_val, 0.0)
                    count_acc += in_mask.to(tl.int32)

            # Divisor priority: explicit override > full kernel area
            # (count_include_pad=True) > number of in-bounds taps.
            if divisor_override != 0:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                )
            elif COUNT_INCLUDE_PAD:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                )
            else:
                divisor = count_acc.to(tl.float32)

            # Guard against divide-by-zero for windows entirely in padding.
            output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)

            # NOTE(review): the store computes offsets as pid_nc*out_h*out_w +
            # h*out_w + w, i.e. it assumes a contiguous NCHW output tensor —
            # the host wrapper allocates it that way.
            out_base_ptr = output_ptr + pid_nc * out_h * out_w
            out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            output_block_ptr = (
                out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
            )

            out_mask = (out_h_offsets[:, None] < out_h) & (
                out_w_offsets[None, :] < out_w
            )
            tl.store(
                output_block_ptr,
                output_vals.to(output_ptr.type.element_ty),
                mask=out_mask,
            )
            pid_hw += grid_1
        pid_nc += grid_0

157 

158 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    # avg_pool2d backward: a "gather" formulation. Each program owns a
    # BLOCK_H x BLOCK_W tile of INPUT positions of one (n, c) plane and sums
    # the contributions of every output window that covers each position, so
    # no atomics are needed.
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    task_num_0,
    # Strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "no override"; any nonzero value replaces the divisor
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Number of input tiles per (n, c) plane; grid-stride loops below cover
    # all tasks even when the launch grid was clamped by limit_grid.
    task_num_1 = tl.cdiv(in_h, BLOCK_H) * tl.cdiv(in_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            num_w_blocks = tl.cdiv(in_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c

            grad_input_block_ptr = (
                grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
            )
            grad_output_base_ptr = (
                grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
            )

            h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

            grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

            # For each kernel tap (kh, kw), solve for the output position whose
            # window would read this input position through that tap:
            #   h_out * stride_h = h_in + padding_h - kh * dilation_h
            for kh_loop in range(kernel_h):
                for kw_loop in range(kernel_w):
                    h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
                    w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w

                    # Valid only if non-negative and exactly divisible by the
                    # stride (otherwise no window aligns with this tap).
                    h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
                    w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)

                    h_out = h_out_num // stride_h
                    w_out = w_out_num // stride_w

                    h_out_mask = h_valid_map & (h_out < out_h)
                    w_out_mask = w_valid_map & (w_out < out_w)
                    out_mask = h_out_mask & w_out_mask

                    # Divisor must match the forward pass: explicit override >
                    # full kernel area > per-window count of in-bounds taps.
                    if divisor_override != 0:
                        divisor = tl.full(
                            (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                        )
                    elif COUNT_INCLUDE_PAD:
                        divisor = tl.full(
                            (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                        )
                    else:
                        # Recount the contributing window's in-bounds elements
                        # (the forward count for output position (h_out, w_out)).
                        h_start = h_out * stride_h - padding_h
                        w_start = w_out * stride_w - padding_w
                        count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
                        for kh_count in range(0, kernel_h):
                            for kw_count in range(0, kernel_w):
                                h_in_for_count = h_start + kh_count * dilation_h
                                w_in_for_count = w_start + kw_count * dilation_w
                                is_valid = (
                                    (h_in_for_count >= 0)
                                    & (h_in_for_count < in_h)
                                    & (w_in_for_count >= 0)
                                    & (w_in_for_count < in_w)
                                )
                                count += is_valid.to(tl.int32)
                        divisor = count.to(tl.float32)

                    # Avoid 0/0 for positions masked out anyway.
                    divisor = tl.where(divisor == 0, 1.0, divisor)

                    grad_out_ptr = (
                        grad_output_base_ptr
                        + h_out * out_stride_h
                        + w_out * out_stride_w
                    )
                    grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
                    grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)

            grad_input_store_ptr = (
                grad_input_block_ptr
                + h_in_offsets[:, None] * in_stride_h
                + w_in_offsets[None, :] * in_stride_w
            )
            in_write_mask = (h_in_offsets[:, None] < in_h) & (
                w_in_offsets[None, :] < in_w
            )
            tl.store(
                grad_input_store_ptr,
                grad_acc.to(grad_input_ptr.type.element_ty),
                mask=in_write_mask,
            )
            pid_hw += grid_1
        pid_nc += grid_0

299 

300 

301def _parse_pool_params(kernel_size, stride, padding): 

302 if isinstance(kernel_size, int): 

303 kernel_h = kernel_w = kernel_size 

304 else: 

305 kernel_h, kernel_w = kernel_size 

306 

307 if stride is None or (isinstance(stride, (list, tuple)) and not stride): 

308 stride_h, stride_w = kernel_h, kernel_w 

309 elif isinstance(stride, int): 

310 stride_h = stride_w = stride 

311 else: 

312 stride_h, stride_w = stride 

313 

314 if isinstance(padding, int): 

315 padding_h = padding_w = padding 

316 else: 

317 padding_h, padding_w = padding 

318 

319 if stride_h <= 0 or stride_w <= 0: 

320 raise ValueError("stride must be greater than zero") 

321 

322 if padding_h < 0 or padding_w < 0: 

323 raise ValueError("padding must be non-negative") 

324 

325 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2: 

326 raise ValueError("pad should be smaller than or equal to half of kernel size") 

327 

328 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w 

329 

330 

def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """2D average pooling over a 4D NCHW tensor via a Triton kernel.

    Mirrors ``torch.nn.functional.avg_pool2d`` semantics, including
    ``ceil_mode``, ``count_include_pad`` and ``divisor_override``.
    """
    logger.debug("GEMS_CAMBRICON AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
    ) = _parse_pool_params(kernel_size, stride, padding)
    # avg_pool2d has no dilation; keep 1 so the shared kernel math holds.
    dilation_h = dilation_w = 1

    in_n, in_c, in_h, in_w = input.shape

    out_h = pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    if output.numel() == 0:
        # Nothing to compute for empty outputs.
        return output

    nc_tasks = in_n * in_c

    def grid(meta):
        # One axis-0 task per (n, c) plane, one axis-1 task per output tile;
        # limit_grid clamps to hardware bounds (kernel grid-strides the rest).
        tile_tasks = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return limit_grid(nc_tasks, tile_tasks)

    avg_pool2d_forward_kernel[grid](
        input,
        output,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        nc_tasks,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        divisor_override=divisor_override if divisor_override is not None else 0.0,
    )

    return output

403 

404 

def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of :func:`avg_pool2d` with respect to its NCHW input.

    The output spatial size is taken from ``grad_output`` itself, so
    ``ceil_mode`` only needs to match the forward call.
    """
    logger.debug("GEMS_CAMBRICON AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    grad_output = grad_output.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
    ) = _parse_pool_params(kernel_size, stride, padding)
    # avg_pool2d has no dilation; keep 1 so the shared kernel math holds.
    dilation_h = dilation_w = 1

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 for accuracy; cast back to the caller's dtype at the end.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_output.numel() == 0:
        return grad_input.to(grad_output.dtype)

    nc_tasks = in_n * in_c

    def grid(meta):
        # Backward tiles the INPUT plane (gather formulation); limit_grid
        # clamps to hardware bounds and the kernel grid-strides the rest.
        tile_tasks = triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(
            in_w, meta["BLOCK_W"]
        )
        return limit_grid(nc_tasks, tile_tasks)

    avg_pool2d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        nc_tasks,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_output.stride(0),
        grad_output.stride(1),
        grad_output.stride(2),
        grad_output.stride(3),
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        divisor_override=divisor_override if divisor_override is not None else 0.0,
    )

    return grad_input.to(grad_output.dtype)