Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max_pool2d_with_indices.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils.limits import get_dtype_min 

9 

10from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Compute one spatial dimension of a max-pool output.

    Mirrors PyTorch's pooling shape arithmetic, including the ceil_mode
    correction that drops a trailing window which would start entirely
    inside the right/bottom padding.

    Args:
        in_size: Input extent along this dimension.
        kernel_size: Pooling window extent (pre-dilation).
        stride: Window step.
        padding: Implicit zero padding added on both sides.
        dilation: Spacing between window elements.
        ceil_mode: Use ceiling instead of floor when dividing by stride.

    Returns:
        The output extent along this dimension.
    """
    # Extent actually covered by the dilated window.
    span = (kernel_size - 1) * dilation + 1
    reach = in_size + 2 * padding - span
    if not ceil_mode:
        return reach // stride + 1

    # Ceiling division via negated floor division.
    size = -(-reach // stride) + 1
    # PyTorch-compatible adjustment: the last window must start inside
    # the (left-padded) input, not purely in the padding region.
    if (size - 1) * stride >= in_size + padding:
        size -= 1
    return size

34 

35 

def limit_grid(grid_0, grid_1):
    """Clamp a 2-D launch grid to the backend's hardware limits.

    Axis 0 is capped at a quarter of MAX_GRID_SIZE_X; axis 1 at
    MAX_GRID_SIZE_Y. Kernels compensate for the clamp with their
    grid-stride while-loops.
    """
    capped_0 = min(grid_0, MAX_GRID_SIZE_X // 4)
    capped_1 = min(grid_1, MAX_GRID_SIZE_Y)
    return capped_0, capped_1

40 

41 

# Forward kernel: each (n, c) plane is handled by axis 0; tiles of the output
# H x W plane are handled by axis 1. Both axes use grid-stride while-loops so
# the launch grid may be smaller than the task count (see limit_grid).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0 (= N * C planes)
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Number of output tiles per (n, c) plane.
    task_num_1 = tl.cdiv(out_h, BLOCK_H) * tl.cdiv(out_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    # Grid-stride loop over (n, c) planes.
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        # Grid-stride loop over output tiles of one plane.
        while pid_hw < task_num_1:
            # Decompose the linear tile id into (row-block, col-block).
            num_w_blocks = tl.cdiv(out_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c

            h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

            # Running max is seeded with the dtype's minimum so padded /
            # out-of-bounds taps never win; running index is seeded with -1.
            dtype = input_ptr.type.element_ty
            min_val = get_dtype_min(dtype)
            max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
            max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

            input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

            # Fully unrolled walk over the pooling window (kernel sizes are
            # constexpr), one masked tile load per tap.
            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    h_in = (
                        h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                    )
                    w_in = (
                        w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                    input_offset = h_in * in_stride_h + w_in * in_stride_w
                    current_val = tl.load(
                        input_base_ptr + input_offset, mask=in_mask, other=min_val
                    )
                    # Flat index within the (in_h, in_w) plane, matching the
                    # layout PyTorch reports for max_pool2d indices.
                    current_idx = h_in * in_w + w_in

                    # Strict '>' keeps the FIRST occurrence on ties.
                    # NOTE(review): '>' is False for NaN, so NaN inputs are
                    # never selected — confirm against the reference op's
                    # NaN-propagation behavior.
                    is_new_max = current_val > max_val_acc
                    max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
                    max_idx_acc = tl.where(
                        is_new_max & in_mask, current_idx, max_idx_acc
                    )

            # Outputs are written assuming a contiguous NCHW layout
            # (plane stride = out_h * out_w).
            out_base_ptr = output_ptr + pid_nc * out_h * out_w
            indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
            out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            output_block_ptr = (
                out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
            )
            indices_block_ptr = (
                indices_base_ptr
                + out_h_offsets[:, None] * out_w
                + out_w_offsets[None, :]
            )

            # Mask off the partial tile at the right/bottom edges.
            out_mask = (out_h_offsets[:, None] < out_h) & (
                out_w_offsets[None, :] < out_w
            )
            tl.store(output_block_ptr, max_val_acc, mask=out_mask)
            tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
            pid_hw += grid_1
        pid_nc += grid_0

155 

156 

# Backward kernel: gather formulation. Each program owns a tile of the INPUT
# plane and, for every window tap, computes which output cell (if any) could
# have pooled that input position, then accumulates the matching upstream
# gradients. This avoids atomics entirely.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices (elements; caller passes contiguous
    # NCHW strides, with out_stride_w == 1 — see the offset math below)
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Total number of tasks on axis 0 (= N * C planes)
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    # Number of input tiles per (n, c) plane.
    task_num_1 = tl.cdiv(in_h, BLOCK_IN_H) * tl.cdiv(in_w, BLOCK_IN_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    nc_idx = tl.program_id(0)
    # Grid-stride loop over (n, c) planes.
    while nc_idx < task_num_0:
        pid_hw = tl.program_id(1)
        # Grid-stride loop over input tiles of one plane.
        while pid_hw < task_num_1:
            num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks

            h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
            w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

            # Flat per-plane index of each input position — the value the
            # forward kernel stored in `indices` when this position won.
            current_input_flat_idx = (
                h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            )
            # Accumulate in fp32 regardless of the I/O dtype.
            grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

            indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
            grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    # Invert the forward mapping
                    #   h_in = h_out * stride_h - padding_h + kh * dilation_h
                    # for the output row/col that reads this input via tap
                    # (kh, kw).
                    numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
                    numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

                    # Only exact multiples of the stride map back to an
                    # output cell.
                    # NOTE(review): for negative numerators this relies on
                    # the '%'/'//' semantics of the Triton build; the bounds
                    # mask below rejects negative h_out/w_out either way —
                    # confirm no stride/dilation combination slips through.
                    valid_map_mask = (numerator_h % stride_h == 0) & (
                        numerator_w % stride_w == 0
                    )
                    h_out = numerator_h // stride_h
                    w_out = numerator_w // stride_w
                    out_bounds_mask = (
                        (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
                    )
                    load_mask = valid_map_mask & out_bounds_mask

                    # Clamp invalid lanes to offset 0 so the masked loads
                    # never compute an out-of-range address.
                    safe_h_out = tl.where(load_mask, h_out, 0)
                    safe_w_out = tl.where(load_mask, w_out, 0)
                    # NOTE: assumes out_stride_w == 1 (safe_w_out is added
                    # without multiplying by out_stride_w); the host wrapper
                    # passes exactly that.
                    out_offsets = safe_h_out * out_stride_h + safe_w_out

                    # Masked-off lanes read -1, which can never equal a
                    # (non-negative) input flat index, so they contribute 0.
                    indices_block = tl.load(
                        indices_base_ptr + out_offsets, mask=load_mask, other=-1
                    )
                    match_mask = indices_block == current_input_flat_idx

                    grad_block = tl.load(
                        grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
                    )
                    grad_acc += grad_block

            # grad_input is written assuming a contiguous NCHW layout.
            grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
            grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
            tl.store(
                grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask
            )
            pid_hw += grid_1
        nc_idx += grid_0

257 

258 

259def _parse_pool_params(kernel_size, stride, padding, dilation): 

260 def _parse_param(param, name, default=None): 

261 if param is None: 

262 return default 

263 if isinstance(param, int): 

264 return param, param 

265 if isinstance(param, (list, tuple)) and len(param) == 2: 

266 return param 

267 raise ValueError(f"Invalid {name}: {param}") 

268 

269 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size") 

270 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w)) 

271 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0)) 

272 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1)) 

273 

274 if stride_h <= 0 or stride_w <= 0: 

275 raise ValueError( 

276 f"stride must be positive, but got stride=({stride_h}, {stride_w})" 

277 ) 

278 if padding_h < 0 or padding_w < 0: 

279 raise ValueError( 

280 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})" 

281 ) 

282 if dilation_h <= 0 or dilation_w <= 0: 

283 raise ValueError( 

284 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})" 

285 ) 

286 

287 return ( 

288 kernel_h, 

289 kernel_w, 

290 stride_h, 

291 stride_w, 

292 padding_h, 

293 padding_w, 

294 dilation_h, 

295 dilation_w, 

296 ) 

297 

298 

def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max pooling over a 4-D NCHW tensor, returning (output, indices).

    `indices` holds, for each output element, the flat position of the
    winning value within its (in_h, in_w) plane — the same convention as
    torch.nn.functional.max_pool2d_with_indices.
    """
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES FORWARD")
    # The forward kernel writes outputs assuming contiguous NCHW.
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    out_shape = (in_n, in_c, out_h, out_w)
    output = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(out_shape, device=input.device, dtype=torch.int64)

    # Nothing to launch for empty tensors.
    if output.numel() == 0:
        return output, indices

    nc_tasks = in_n * in_c

    def grid(meta):
        # One axis-0 task per (n, c) plane, one axis-1 task per output tile;
        # limit_grid clamps to hardware bounds (kernels grid-stride the rest).
        hw_tiles = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return limit_grid(nc_tasks, hw_tiles)

    max_pool2d_forward_kernel[grid](
        input,
        output,
        indices,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        nc_tasks,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return output, indices

373 

374 

def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Gradient of max_pool2d_with_indices w.r.t. `input`.

    Routes each upstream gradient back to the input position recorded in
    `indices`. `ceil_mode` is accepted for interface parity with the
    forward op; the backward mapping does not need it.
    """
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES BACKWARD")
    # The backward kernel indexes grad_output/indices with contiguous
    # NCHW strides.
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 for precision; cast back to the incoming dtype
    # on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    nc_tasks = in_n * in_c

    def grid(meta):
        # Axis 0 walks (n, c) planes, axis 1 walks input tiles; limit_grid
        # clamps to hardware bounds (kernels grid-stride the rest).
        in_tiles = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return limit_grid(nc_tasks, in_tiles)

    max_pool2d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_h,
        in_w,
        out_h,
        out_w,
        out_h * out_w,  # out_stride_nc: contiguous NCHW plane stride
        out_w,  # out_stride_h
        1,  # out_stride_w (kernel assumes exactly 1)
        nc_tasks,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return grad_input.to(grad_output.dtype)