# Coverage report header (coverage.py v7.6.9, created at 2026-03-09 01:57 +0800):
# src/flag_gems/ops/max_pool2d_with_indices.py — 48% of 141 statements covered.
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils.limits import get_dtype_min
10logger = logging.getLogger(__name__)
13def max_pool2d_output_size(
14 in_size: int,
15 kernel_size: int,
16 stride: int,
17 padding: int,
18 dilation: int,
19 ceil_mode: bool = False,
20) -> int:
21 effective_kernel_size = (kernel_size - 1) * dilation + 1
22 numerator = in_size + 2 * padding - effective_kernel_size
23 if ceil_mode:
24 output_size = (numerator + stride - 1) // stride + 1
25 # PyTorch-compatible adjustment for ceil_mode
26 if (output_size - 1) * stride >= in_size + padding:
27 output_size -= 1
28 else:
29 output_size = numerator // stride + 1
31 return output_size
# Forward max-pool kernel.
# Grid: axis 0 = one program per (batch, channel) pair; axis 1 = one program
# per (BLOCK_H x BLOCK_W) tile of the output plane.  Each program scans the
# kernel_h x kernel_w window for every output element in its tile, writing
# the max value and the flat per-plane argmax index (h * in_w + w), the same
# index layout PyTorch's max_pool2d_with_indices produces.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters (constexpr so the window loops fully unroll)
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Decompose the two program ids into (n, c) and (tile_h, tile_w).
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    # Output coordinates covered by this tile; may overhang past
    # out_h/out_w — the final store mask clips them.
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Running max / argmax accumulators.  Seeding with the dtype minimum
    # makes out-of-bounds (padding) taps lose every strict '>' comparison.
    dtype = input_ptr.type.element_ty
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

    # Base of this program's (n, c) input plane.
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Input coordinates read by window tap (kh, kw) for each
            # output element of the tile (negative values = padding).
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=min_val
            )
            # Flat index within the (in_h, in_w) plane — PyTorch convention.
            current_idx = h_in * in_w + w_in

            # Strict '>' keeps the first (row-major earliest) max on ties;
            # in_mask prevents recording an index for a padding tap.
            # NOTE(review): if every in-window value equals the dtype
            # minimum, no '>' ever fires and the index stays -1 — confirm
            # this edge case is acceptable to the backward pass.
            is_new_max = current_val > max_val_acc
            max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
            max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)

    # Output/indices tensors are addressed as contiguous NCHW: each (n, c)
    # plane has stride out_h * out_w.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
    # (Recomputes the same offsets as h_out_offsets/w_out_offsets above.)
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    indices_block_ptr = (
        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    # Clip the tile to the real output extent and write both results.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
# Backward max-pool kernel (gather formulation — no atomics).
# Grid: axis 0 = one program per (batch, channel) pair; axis 1 = one program
# per (BLOCK_IN_H x BLOCK_IN_W) tile of the *input* plane.  Each input
# element enumerates every output position whose window could contain it and
# accumulates grad_output wherever the stored argmax index equals its own
# flat index.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices (in elements)
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Pooling parameters (constexpr so the window loops fully unroll)
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    nc_idx = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Decompose the tile id over the input plane.
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    # Input coordinates handled by this tile.
    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

    # Flat per-plane index of each input element — the value the forward
    # kernel recorded in `indices` for its argmax.
    current_input_flat_idx = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    # Accumulate in fp32 regardless of the gradient dtype.
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Invert h_in = h_out*stride - padding + kh*dilation for h_out
            # (and likewise for w).
            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
            numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

            # Only exactly-divisible positions map back to a real output.
            # Positions with negative numerators are rejected either here
            # (nonzero remainder) or by the h_out/w_out >= 0 check below.
            valid_map_mask = (numerator_h % stride_h == 0) & (
                numerator_w % stride_w == 0
            )
            h_out = numerator_h // stride_h
            w_out = numerator_w // stride_w
            out_bounds_mask = (
                (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
            )
            load_mask = valid_map_mask & out_bounds_mask

            # Clamp masked-off lanes to offset 0 so no out-of-range address
            # is ever formed, even though the load is masked.
            safe_h_out = tl.where(load_mask, h_out, 0)
            safe_w_out = tl.where(load_mask, w_out, 0)
            # NOTE(review): out_stride_w is accepted but unused here — the
            # last axis is assumed contiguous (stride 1); confirm callers
            # always pass contiguous grad_output/indices.
            out_offsets = safe_h_out * out_stride_h + safe_w_out

            indices_block = tl.load(
                indices_base_ptr + out_offsets, mask=load_mask, other=-1
            )
            # The -1 fill can never equal a (non-negative) flat input index,
            # so match_mask implies load_mask.
            match_mask = indices_block == current_input_flat_idx

            grad_block = tl.load(
                grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
            )
            grad_acc += grad_block

    # grad_input is addressed as contiguous: plane stride is in_h * in_w.
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
    grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)
220def _parse_pool_params(kernel_size, stride, padding, dilation):
221 def _parse_param(param, name, default=None):
222 if param is None:
223 return default
224 if isinstance(param, int):
225 return param, param
226 if isinstance(param, (list, tuple)) and len(param) == 2:
227 return param
228 raise ValueError(f"Invalid {name}: {param}")
230 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size")
231 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w))
232 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0))
233 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1))
235 if stride_h <= 0 or stride_w <= 0:
236 raise ValueError(
237 f"stride must be positive, but got stride=({stride_h}, {stride_w})"
238 )
239 if padding_h < 0 or padding_w < 0:
240 raise ValueError(
241 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})"
242 )
243 if dilation_h <= 0 or dilation_w <= 0:
244 raise ValueError(
245 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})"
246 )
248 return (
249 kernel_h,
250 kernel_w,
251 stride_h,
252 stride_w,
253 padding_h,
254 padding_w,
255 dilation_h,
256 dilation_w,
257 )
def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max-pool a 4D NCHW tensor; return (pooled values, flat argmax indices).

    Indices are flat offsets into each (H, W) plane, matching the layout
    of ``torch.nn.functional.max_pool2d_with_indices``.
    """
    logger.debug("GEMS MAX_POOL2D_WITH_INDICES FORWARD")
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    out_shape = (in_n, in_c, out_h, out_w)
    output = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(out_shape, device=input.device, dtype=torch.int64)

    # Nothing to compute for degenerate (zero-element) outputs.
    if output.numel() != 0:

        def grid(meta):
            # One program per (n, c) pair, times one per output tile.
            tiles = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
                out_w, meta["BLOCK_W"]
            )
            return (in_n * in_c, tiles)

        max_pool2d_forward_kernel[grid](
            input,
            output,
            indices,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return output, indices
def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Compute the gradient of max_pool2d with respect to ``input``.

    ``indices`` must be the flat argmax indices from the forward pass.
    ``ceil_mode`` is accepted for signature parity; the output extent is
    read directly from ``grad_output``'s shape.
    """
    logger.debug("GEMS MAX_POOL2D BACKWARD")
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2:4]

    # Accumulate in fp32 for precision, then cast back on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)
    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    def grid(meta):
        # One program per (n, c) pair, times one per input tile.
        tiles = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return (in_n * in_c, tiles)

    # grad_output/indices are contiguous, so strides follow from shape.
    max_pool2d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_h,
        in_w,
        out_h,
        out_w,
        out_h * out_w,  # out_stride_nc
        out_w,  # out_stride_h
        1,  # out_stride_w
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return grad_input.to(grad_output.dtype)