Coverage for src/flag_gems/ops/max_pool2d_with_indices.py: 48%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

import logging

import torch
import triton
import triton.language as tl

# Project-local helpers: libentry wraps Triton kernels for dispatch/caching,
# get_dtype_min supplies the smallest representable value per dtype.
from flag_gems.utils import libentry
from flag_gems.utils.limits import get_dtype_min

# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)

11 

12 

def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the size of one pooled spatial dimension.

    Implements the standard max-pool sizing rule: the dilated kernel spans
    ``(kernel_size - 1) * dilation + 1`` elements, and the window slides over
    the padded input with the given stride.  With ``ceil_mode`` the count is
    rounded up, subject to the PyTorch-compatible constraint that the last
    window must start inside the input-plus-left-padding region.
    """
    span = (kernel_size - 1) * dilation + 1
    slack = in_size + 2 * padding - span

    if not ceil_mode:
        return slack // stride + 1

    # Ceiling division via negation, then the PyTorch ceil_mode adjustment.
    out = -(-slack // stride) + 1
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out

32 

33 

@libentry()
@triton.autotune(
    configs=[
        # Candidate output-tile shapes; smaller tiles pair with fewer warps
        # and more pipeline stages, larger tiles with more warps.
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    # Re-tune whenever the output geometry or window/stride parameters change.
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides (input may be non-contiguous; strides are explicit)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters (constexpr so the window loops fully unroll)
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Compute max-pool values and argmax indices for one output tile.

    Grid axis 0 enumerates (n, c) planes; axis 1 enumerates BLOCK_H x BLOCK_W
    tiles of the output plane.  For each output element the (dilated) pooling
    window is scanned while tracking the running maximum and the flat index
    ``h_in * in_w + w_in`` of where it was found.  Output and indices are
    written assuming densely-packed (n*c, out_h, out_w) planes.
    """
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Decompose the flat tile id into (tile row, tile column) of the output.
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    # Decompose the flat plane id into batch and channel indices.
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Seed the running max with the dtype's minimum so any in-bounds value wins.
    dtype = input_ptr.type.element_ty
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    # -1 marks "no valid element seen" (only possible when the whole window
    # falls outside the input).
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    # Fully-unrolled scan over the pooling window taps.
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Input coordinates this tap samples, accounting for padding and
            # dilation; they may fall outside the input at the borders.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            # Out-of-bounds taps read min_val, so they can never become the max.
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=min_val
            )
            # Flat position within the (n, c) input plane.
            current_idx = h_in * in_w + w_in

            # Strict '>' means ties keep the earlier (lowest-index) candidate.
            is_new_max = current_val > max_val_acc
            max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
            max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)

    # Outputs are addressed as contiguous (n*c, out_h, out_w) planes.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    indices_block_ptr = (
        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    # Mask off tile elements past the right/bottom edges of the output plane.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)

128 

129 

@libentry()
@triton.autotune(
    configs=[
        # Candidate input-tile shapes (backward tiles over grad_input).
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices
    out_stride_nc,
    out_stride_h,
    out_stride_w,  # NOTE(review): accepted but never used below; the offset math assumes a unit W stride
    # Pooling parameters (constexpr so the window loops fully unroll)
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    """Accumulate grad_input for one input tile, gather-style.

    Rather than scattering each grad_output element to its argmax position
    (which would require atomics), every program owns a BLOCK_IN_H x
    BLOCK_IN_W tile of grad_input and, for each pooling-window tap, inverts
    the pooling map to find the output cells that could have selected this
    input pixel.  A contribution is added only where the stored argmax index
    matches the pixel's own flat index.
    """
    nc_idx = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Decompose the flat tile id into (tile row, tile column) of the input.
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

    # Flat index of each tile element, in the same h * in_w + w encoding the
    # forward kernel stored into `indices`.
    current_input_flat_idx = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    # Accumulate in fp32 regardless of the gradient dtype.
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Invert h_in = h_out * stride - padding + k * dilation for h_out.
            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
            numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

            # A tap maps back to an output cell only when the division is exact.
            valid_map_mask = (numerator_h % stride_h == 0) & (
                numerator_w % stride_w == 0
            )
            h_out = numerator_h // stride_h
            w_out = numerator_w // stride_w
            out_bounds_mask = (
                (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
            )
            load_mask = valid_map_mask & out_bounds_mask

            # Clamp masked-off lanes to offset 0 so the address arithmetic
            # stays in bounds; their loads are disabled by the mask anyway.
            safe_h_out = tl.where(load_mask, h_out, 0)
            safe_w_out = tl.where(load_mask, w_out, 0)
            out_offsets = safe_h_out * out_stride_h + safe_w_out

            # other=-1 can never equal a tile element's flat index (>= 0),
            # so masked lanes fail the match below.
            indices_block = tl.load(
                indices_base_ptr + out_offsets, mask=load_mask, other=-1
            )
            match_mask = indices_block == current_input_flat_idx

            # Load the gradient only where this pixel actually was the argmax.
            grad_block = tl.load(
                grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
            )
            grad_acc += grad_block

    # grad_input is written assuming contiguous (n*c, in_h, in_w) planes.
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
    grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)

218 

219 

220def _parse_pool_params(kernel_size, stride, padding, dilation): 

221 def _parse_param(param, name, default=None): 

222 if param is None: 

223 return default 

224 if isinstance(param, int): 

225 return param, param 

226 if isinstance(param, (list, tuple)) and len(param) == 2: 

227 return param 

228 raise ValueError(f"Invalid {name}: {param}") 

229 

230 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size") 

231 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w)) 

232 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0)) 

233 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1)) 

234 

235 if stride_h <= 0 or stride_w <= 0: 

236 raise ValueError( 

237 f"stride must be positive, but got stride=({stride_h}, {stride_w})" 

238 ) 

239 if padding_h < 0 or padding_w < 0: 

240 raise ValueError( 

241 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})" 

242 ) 

243 if dilation_h <= 0 or dilation_w <= 0: 

244 raise ValueError( 

245 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})" 

246 ) 

247 

248 return ( 

249 kernel_h, 

250 kernel_w, 

251 stride_h, 

252 stride_w, 

253 padding_h, 

254 padding_w, 

255 dilation_h, 

256 dilation_w, 

257 ) 

258 

259 

def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max pooling with argmax indices, computed by a Triton kernel.

    Args:
        input: 4D ``(N, C, H, W)`` tensor, or — generalizing the previous
            4D-only behavior — an unbatched 3D ``(C, H, W)`` tensor, as
            accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size, stride, padding, dilation: int or pair of ints;
            ``stride=None`` defaults to ``kernel_size``.
        ceil_mode: round the output size up instead of down.

    Returns:
        ``(output, indices)`` with the same leading dims as ``input``.
        ``indices`` (int64) holds the flat ``h * W_in + w`` position of each
        maximum within its (n, c) plane, matching PyTorch's convention.

    Raises:
        ValueError: on a non-3D/4D input or malformed pooling parameters.
    """
    logger.debug("GEMS MAX_POOL2D_WITH_INDICES FORWARD")

    if input.dim() not in (3, 4):
        raise ValueError(
            f"max_pool2d_with_indices expects a 3D or 4D input, got {input.dim()}D"
        )
    # Run an unbatched (C, H, W) input as a batch of one; strip the batch
    # dimension again before returning.
    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    indices = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=torch.int64
    )

    # Launch only when there is work to do; an empty output needs no kernel.
    if output.numel() != 0:
        # Axis 0: one program per (n, c) plane; axis 1: output-plane tiles.
        grid = lambda meta: (
            in_n * in_c,
            triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
        )
        max_pool2d_forward_kernel[grid](
            input,
            output,
            indices,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    if unbatched:
        output = output.squeeze(0)
        indices = indices.squeeze(0)
    return output, indices

330 

331 

def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Backward of max pooling: route each grad_output element to the input
    position recorded in ``indices``.

    Args:
        grad_output: gradient w.r.t. the pooled output; 4D ``(N, C, OH, OW)``
            or — generalizing the previous 4D-only behavior — 3D matching an
            unbatched forward input.
        input: the forward input (only its shape/dtype/device are used here).
        indices: flat argmax indices produced by the forward pass.
        kernel_size, stride, padding, dilation: as passed to the forward call.
        ceil_mode: accepted for signature parity with the forward; output
            geometry is read from ``grad_output``, so it is not re-derived.

    Returns:
        Gradient w.r.t. ``input``, with input's shape and grad_output's dtype.
    """
    logger.debug("GEMS MAX_POOL2D BACKWARD")

    if input.dim() not in (3, 4):
        raise ValueError(
            f"max_pool2d_backward expects a 3D or 4D input, got {input.dim()}D"
        )
    # Treat unbatched (C, H, W) tensors as a batch of one, like the forward.
    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)
        indices = indices.unsqueeze(0)

    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in float32 for numerical robustness; cast back on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_input.numel() != 0:
        # Axis 0: one program per (n, c) plane; axis 1: grad_input tiles.
        grid = lambda meta: (
            in_n * in_c,
            triton.cdiv(in_h, meta["BLOCK_IN_H"])
            * triton.cdiv(in_w, meta["BLOCK_IN_W"]),
        )
        max_pool2d_backward_kernel[grid](
            grad_output,
            indices,
            grad_input,
            in_h,
            in_w,
            out_h,
            out_w,
            out_h * out_w,  # out_stride_nc: contiguous plane size
            out_w,  # out_stride_h
            1,  # out_stride_w: grad_output/indices are contiguous
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    result = grad_input.to(grad_output.dtype)
    if unbatched:
        result = result.squeeze(0)
    return result