Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/max_pool2d_with_indices.py: 0%

145 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils.limits import get_dtype_min 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the pooled length of one spatial dimension.

    Uses the standard pooling shape formula: the effective kernel span is
    ``(kernel_size - 1) * dilation + 1``, and the window start positions run
    over ``in_size + 2 * padding - span`` in steps of ``stride``.

    Args:
        in_size: Input length along this dimension.
        kernel_size: Pooling window length.
        stride: Window step.
        padding: Implicit padding added on both sides.
        dilation: Spacing between window taps.
        ceil_mode: When True, round the division up instead of down, but
            never allow a window that starts entirely inside the right
            padding region.

    Returns:
        The output length along this dimension.
    """
    span = (kernel_size - 1) * dilation + 1
    usable = in_size + 2 * padding - span

    if not ceil_mode:
        return usable // stride + 1

    # Ceiling division via negated floor division.
    size = -(-usable // stride) + 1
    # The extra ceil-mode window is dropped when it would begin past the end
    # of the (left-padded) input, i.e. fully inside the right padding.
    if (size - 1) * stride >= in_size + padding:
        size -= 1
    return size

33 

34 

# Forward max-pooling kernel. Each program instance owns one (n, c) plane
# (grid axis 0) and one BLOCK_H x BLOCK_W tile of the output (grid axis 1).
# It scans the pooling window with a compile-time-unrolled double loop,
# tracking both the running maximum and the flat (h * in_w + w) index of the
# element that produced it.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Decompose the two grid axes into (n, c) and the (h, w) tile coordinates.
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Seed the accumulator with the dtype's minimum so any real element wins;
    # indices start at -1, meaning "no valid element seen yet".
    dtype = input_ptr.type.element_ty
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int32)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    # kernel_h / kernel_w are constexpr, so both loops fully unroll.
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Input coordinates sampled by tap (kh, kw) for every output
            # position of the tile; padding shifts left/up, dilation spreads
            # the taps apart.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            # Out-of-bounds taps read min_val and therefore cannot win.
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=min_val
            )
            # Flat index within the (h, w) plane of this channel.
            current_idx = h_in * in_w + w_in

            # Strict '>' keeps the earliest tap on ties. The index is only
            # updated for in-bounds winners, so padding never leaks into it.
            is_new_max = current_val > max_val_acc
            max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
            max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)

    # Output and indices are addressed as contiguous NCHW: each (n, c) plane
    # is out_h * out_w elements.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    indices_block_ptr = (
        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    # Mask off the tile's overhang past the right/bottom edge of the output.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)

118 

119 

# Backward max-pooling kernel, written as a gather so no atomics are needed:
# each program instance owns one (n, c) plane and one BLOCK_IN_H x BLOCK_IN_W
# tile of grad_input, and sums the gradients of every output cell whose
# recorded argmax index equals the input position it owns.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    nc_idx = tl.program_id(0)
    pid_hw = tl.program_id(1)

    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

    # Flat (h * in_w + w) id of each owned input position, to compare against
    # the argmax indices recorded by the forward kernel.
    current_input_flat_idx = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    # Accumulate in fp32 regardless of the I/O dtype.
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Invert the forward mapping: which output cell reaches this input
            # position through tap (kh, kw)?  h_in = h_out*stride - pad + kh*dil
            # solved for h_out.
            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
            numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

            # Only positions that land exactly on the stride grid map back to
            # an integer output coordinate.
            valid_map_mask = (numerator_h % stride_h == 0) & (
                numerator_w % stride_w == 0
            )
            h_out = numerator_h // stride_h
            w_out = numerator_w // stride_w
            out_bounds_mask = (
                (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
            )
            load_mask = valid_map_mask & out_bounds_mask

            # Clamp invalid coordinates to 0 so even masked-off lanes compute
            # an in-range address.
            safe_h_out = tl.where(load_mask, h_out, 0)
            safe_w_out = tl.where(load_mask, w_out, 0)
            # NOTE(review): out_stride_w is accepted but unused here — the w
            # stride is assumed to be 1 (contiguous rows).
            out_offsets = safe_h_out * out_stride_h + safe_w_out

            indices_block = tl.load(
                indices_base_ptr + out_offsets, mask=load_mask, other=-1
            )
            # Gradient flows only to the position that actually won the max.
            match_mask = indices_block == current_input_flat_idx

            grad_block = tl.load(
                grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
            )
            grad_acc += grad_block

    # grad_input is addressed as contiguous NCHW (plane stride in_h * in_w).
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
    grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)

204 

205 

206def _parse_pool_params(kernel_size, stride, padding, dilation): 

207 def _parse_param(param, name, default=None): 

208 if param is None: 

209 return default 

210 if isinstance(param, int): 

211 return param, param 

212 if isinstance(param, (list, tuple)) and len(param) == 2: 

213 return param 

214 raise ValueError(f"Invalid {name}: {param}") 

215 

216 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size") 

217 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w)) 

218 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0)) 

219 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1)) 

220 

221 if stride_h <= 0 or stride_w <= 0: 

222 raise ValueError( 

223 f"stride must be positive, but got stride=({stride_h}, {stride_w})" 

224 ) 

225 if padding_h < 0 or padding_w < 0: 

226 raise ValueError( 

227 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})" 

228 ) 

229 if dilation_h <= 0 or dilation_w <= 0: 

230 raise ValueError( 

231 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})" 

232 ) 

233 

234 return ( 

235 kernel_h, 

236 kernel_w, 

237 stride_h, 

238 stride_w, 

239 padding_h, 

240 padding_w, 

241 dilation_h, 

242 dilation_w, 

243 ) 

244 

245 

def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max-pool a 3D (C, H, W) or 4D (N, C, H, W) tensor.

    Args:
        input: Input tensor; made contiguous before the kernel launch.
        kernel_size: Window size — int or (h, w) sequence.
        stride: Window step; defaults to ``kernel_size``.
        padding: Implicit negative-infinity padding on each side.
        dilation: Spacing between window taps.
        ceil_mode: Round output size up instead of down.

    Returns:
        ``(output, indices)`` — pooled values and, per output element, the
        flat ``h * W + w`` position of the max within its (h, w) plane.
        Indices are int32 (the kernel accumulates them as int32).

    Raises:
        ValueError: on invalid pooling parameters or input rank.
    """
    logger.debug("GEMS MAX_POOL2D_WITH_INDICES FORWARD")

    # Accept an unbatched (C, H, W) input as well as (N, C, H, W), like the
    # ATen operator: pool it as a batch of one, then squeeze the batch back.
    if input.dim() == 3:
        output, indices = max_pool2d_with_indices(
            input.unsqueeze(0), kernel_size, stride, padding, dilation, ceil_mode
        )
        return output.squeeze(0), indices.squeeze(0)
    if input.dim() != 4:
        raise ValueError(
            f"max_pool2d expects a 3D or 4D input, but got {input.dim()}D"
        )

    input = input.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding, dilation)
    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = params

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    indices = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=torch.int32
    )

    # Nothing to compute for empty batches/channels/outputs.
    if output.numel() == 0:
        return output, indices

    # One program per (n, c) plane on axis 0; output tiles on axis 1.
    grid = lambda meta: (
        in_n * in_c,
        triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
    )

    with torch_device_fn.device(input.device):
        max_pool2d_forward_kernel[grid](
            input,
            output,
            indices,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return output, indices

317 

318 

def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Compute the input gradient of max_pool2d from the saved argmax indices.

    Args:
        grad_output: Gradient w.r.t. the pooled output.
        input: Forward-pass input (used only for shape/dtype/device).
        indices: Flat per-plane argmax indices from the forward pass.
        kernel_size/stride/padding/dilation: Pooling parameters, as given to
            the forward pass.
        ceil_mode: Accepted for interface symmetry; the output shape is taken
            from ``grad_output`` directly.

    Returns:
        ``grad_input`` with the same shape as ``input``, cast back to
        ``grad_output``'s original dtype.
    """
    logger.debug("GEMS MAX_POOL2D BACKWARD")

    result_dtype = grad_output.dtype
    # The kernel accumulates in fp32 and compares int32 argmax indices.
    grad_out_f32 = grad_output.to(torch.float32).contiguous()
    idx_i32 = indices.to(torch.int32).contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    batch, channels, in_h, in_w = input.shape
    out_h, out_w = grad_out_f32.shape[2], grad_out_f32.shape[3]

    grad_input = torch.zeros_like(input, dtype=torch.float32)
    if grad_input.numel() == 0:
        return grad_input.to(result_dtype)

    # One program per (n, c) plane on axis 0; grad_input tiles on axis 1.
    def grid(meta):
        tiles = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return (batch * channels, tiles)

    with torch_device_fn.device(grad_input.device):
        max_pool2d_backward_kernel[grid](
            grad_out_f32,
            idx_i32,
            grad_input,
            channels,
            in_h,
            in_w,
            out_h,
            out_w,
            out_h * out_w,  # out_stride_nc: plane stride (contiguous NCHW)
            out_w,  # out_stride_h
            1,  # out_stride_w
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return grad_input.to(result_dtype)