Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/max_pool2d_with_indices.py: 0%
145 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils.limits import get_dtype_min
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return one spatial dimension of the pooled output, PyTorch-compatible.

    The window's receptive field is ``dilation * (kernel_size - 1) + 1``;
    the remaining room after padding is divided by ``stride`` (floor, or
    ceil when ``ceil_mode``). In ceil mode, a final window that would start
    entirely inside the right padding is dropped, matching PyTorch.
    """
    receptive = dilation * (kernel_size - 1) + 1
    room = in_size + 2 * padding - receptive
    if not ceil_mode:
        return room // stride + 1
    # Ceiling division via negated floor division.
    size = -(-room // stride) + 1
    # PyTorch rule: the last window must begin before the right padding.
    if (size - 1) * stride >= in_size + padding:
        size -= 1
    return size
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Forward max-pool: one program computes a BLOCK_H x BLOCK_W output tile.

    Grid: axis 0 enumerates (n, c) planes, axis 1 enumerates output tiles.
    For each output element the kernel writes the window max to output_ptr
    and its argmax — as a flat int32 offset ``h * in_w + w`` within the
    plane — to indices_ptr.  The input may carry arbitrary strides, but
    output/indices are addressed with contiguous NCHW arithmetic
    (``pid_nc * out_h * out_w``), so both must be contiguous.
    """
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Decompose the flat tile id into (tile row, tile col) of the output grid.
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    # Seed the running max with the dtype minimum so padded / out-of-bounds
    # taps never win; index -1 marks "no valid element seen yet".
    dtype = input_ptr.type.element_ty
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int32)
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    # Fully unrolled sweep over the window (kernel_h/kernel_w are constexpr).
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Input coordinates read by this (kh, kw) tap for every output
            # element in the tile; padding shifts the window left/up.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=min_val
            )
            # Flat per-plane index; the backward kernel compares against this.
            current_idx = h_in * in_w + w_in
            # Strict '>' keeps the earliest (smallest kh/kw) position on ties.
            is_new_max = current_val > max_val_acc
            max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
            max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)
    # Contiguous NCHW addressing for the two result tensors.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    indices_block_ptr = (
        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    # Mask off the tile's overhang beyond the right/bottom output edges.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    """Scatter-free max-pool backward: each program owns a grad_input tile.

    Rather than atomically scattering grad_output to argmax locations, every
    grad_input element gathers over the kernel_h x kernel_w output positions
    whose window could contain it and accumulates grad_output wherever the
    stored argmax index equals its own flat index — deterministic and
    race-free.  grad_input is addressed with contiguous NCHW arithmetic
    (``nc * in_h * in_w``) and must be contiguous.  Note: ``in_c`` and
    ``out_stride_w`` are accepted but never read (the W addressing below
    assumes a unit W-stride, which the wrapper guarantees).
    """
    nc_idx = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Decompose the flat tile id into (tile row, tile col) of the input plane.
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)
    # Flat per-plane index of each input element — the same encoding the
    # forward kernel wrote into `indices`.
    current_input_flat_idx = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)
    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Invert the forward mapping for window tap (kh, kw):
            # h_in = h_out * stride_h - padding_h + kh * dilation_h
            # => h_out = (h_in + padding_h - kh * dilation_h) / stride_h.
            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
            numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w
            # Only exact multiples of the stride map back to an output cell.
            # (For negative numerators the divisibility mask and the >= 0
            # bounds mask together reject the lane under either floor or
            # truncating //,% semantics.)
            valid_map_mask = (numerator_h % stride_h == 0) & (
                numerator_w % stride_w == 0
            )
            h_out = numerator_h // stride_h
            w_out = numerator_w // stride_w
            out_bounds_mask = (
                (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
            )
            load_mask = valid_map_mask & out_bounds_mask
            # Clamp masked-off lanes to 0 so the computed address stays legal.
            safe_h_out = tl.where(load_mask, h_out, 0)
            safe_w_out = tl.where(load_mask, w_out, 0)
            out_offsets = safe_h_out * out_stride_h + safe_w_out
            indices_block = tl.load(
                indices_base_ptr + out_offsets, mask=load_mask, other=-1
            )
            # Gradient flows only where this input element was the argmax
            # (other=-1 can never match a valid non-negative flat index).
            match_mask = indices_block == current_input_flat_idx
            grad_block = tl.load(
                grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
            )
            grad_acc += grad_block
    # Contiguous NCHW addressing: plane base is nc * in_h * in_w.
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
    grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    # Mask off the tile's overhang beyond the right/bottom input edges.
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)
def _parse_pool_params(kernel_size, stride, padding, dilation):
    """Normalize pooling hyper-parameters to per-axis ints.

    Each of ``kernel_size``/``stride``/``padding``/``dilation`` may be an
    int (applied to both axes) or a 1- or 2-element list/tuple (H, W).
    ``None`` or an empty sequence selects the default — aten passes
    ``stride=[]`` when the caller did not specify one, meaning "use
    kernel_size".

    Returns:
        Tuple ``(kernel_h, kernel_w, stride_h, stride_w, padding_h,
        padding_w, dilation_h, dilation_w)``.

    Raises:
        ValueError: if a parameter is malformed, ``kernel_size`` is missing
            or non-positive, ``stride``/``dilation`` are non-positive, or
            ``padding`` is negative.
    """

    def _parse_param(param, name, default=None):
        # None or an empty sequence means "use the default".
        if param is None:
            return default
        if isinstance(param, int):
            return param, param
        if isinstance(param, (list, tuple)):
            if len(param) == 0:
                return default
            if len(param) == 1:
                return param[0], param[0]
            if len(param) == 2:
                return param[0], param[1]
        raise ValueError(f"Invalid {name}: {param}")

    parsed_kernel = _parse_param(kernel_size, "kernel_size")
    if parsed_kernel is None:
        # Previously a missing kernel_size surfaced as an opaque TypeError
        # on tuple unpacking; fail with a clear message instead.
        raise ValueError("kernel_size must be specified")
    kernel_h, kernel_w = parsed_kernel
    stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w))
    padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0))
    dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1))

    if kernel_h <= 0 or kernel_w <= 0:
        raise ValueError(
            f"kernel_size must be positive, but got kernel_size=({kernel_h}, {kernel_w})"
        )
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(
            f"stride must be positive, but got stride=({stride_h}, {stride_w})"
        )
    if padding_h < 0 or padding_w < 0:
        raise ValueError(
            f"padding must be non-negative, but got padding=({padding_h}, {padding_w})"
        )
    if dilation_h <= 0 or dilation_w <= 0:
        raise ValueError(
            f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})"
        )

    return (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )
def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max pooling over a 4D NCHW tensor; returns (values, indices).

    Indices are flattened per (n, c) plane as ``h * in_w + w`` and stored
    as int32.  NOTE(review): torch's reference op returns int64 indices —
    confirm downstream consumers tolerate the narrower dtype.
    """
    logger.debug("GEMS MAX_POOL2D_WITH_INDICES FORWARD")
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    batch, channels, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    result_shape = (batch, channels, out_h, out_w)
    output = torch.empty(result_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(result_shape, device=input.device, dtype=torch.int32)

    # Degenerate output (some dimension is zero): nothing to launch.
    if output.numel() == 0:
        return output, indices

    def grid(meta):
        # One program per (n, c) plane on axis 0; output tiles on axis 1.
        tiles = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return (batch * channels, tiles)

    with torch_device_fn.device(input.device):
        max_pool2d_forward_kernel[grid](
            input,
            output,
            indices,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            channels,
            in_h,
            in_w,
            out_h,
            out_w,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return output, indices
def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Compute grad_input for max_pool2d from grad_output and saved indices.

    Args:
        grad_output: Upstream gradient, shape (N, C, out_h, out_w).
        input: Forward-pass input; only its shape and device are consulted.
        indices: Per-plane flat argmax indices from the forward pass.
        kernel_size, stride, padding, dilation: As in the forward op.
        ceil_mode: Accepted for interface parity; the output spatial size is
            read from ``grad_output`` directly, so it is not used here.

    Returns:
        grad_input with ``input``'s shape and ``grad_output``'s original dtype.
    """
    logger.debug("GEMS MAX_POOL2D BACKWARD")
    original_dtype = grad_output.dtype
    # The kernel loads/accumulates fp32 and compares int32 indices; convert
    # once at the boundary and restore the dtype on return.
    grad_output = grad_output.to(torch.float32).contiguous()
    indices = indices.to(torch.int32).contiguous()

    params = _parse_pool_params(kernel_size, stride, padding, dilation)
    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = params

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # FIX: the kernel writes grad_input with contiguous NCHW addressing
    # (plane base = nc * in_h * in_w).  torch.zeros_like(input, ...) preserves
    # a non-contiguous input's strides, so the kernel would then write to the
    # wrong locations; allocate an explicitly contiguous buffer instead.
    grad_input = torch.zeros(
        (in_n, in_c, in_h, in_w), device=input.device, dtype=torch.float32
    )

    # Degenerate shapes: nothing to launch.
    if grad_input.numel() == 0:
        return grad_input.to(original_dtype)

    grid = lambda meta: (
        in_n * in_c,
        triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(in_w, meta["BLOCK_IN_W"]),
    )

    # grad_output/indices were made contiguous above, so their strides are
    # the canonical NCHW ones (unit W-stride, as the kernel assumes).
    out_stride_nc = out_h * out_w
    out_stride_h = out_w
    out_stride_w = 1

    with torch_device_fn.device(grad_input.device):
        max_pool2d_backward_kernel[grid](
            grad_output,
            indices,
            grad_input,
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            out_stride_nc,
            out_stride_h,
            out_stride_w,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return grad_input.to(original_dtype)