Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max_pool2d_with_indices.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils.limits import get_dtype_min
10from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y
12logger = logging.getLogger(__name__)
def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Compute one spatial output extent of max_pool2d (PyTorch semantics)."""
    # Span of the dilated window and the room left for placing its first tap.
    window_span = dilation * (kernel_size - 1) + 1
    free_extent = in_size + 2 * padding - window_span
    if not ceil_mode:
        return free_extent // stride + 1
    out = -(-free_extent // stride) + 1  # ceil division
    # PyTorch drops the last window when it would start entirely inside the
    # right/bottom padding region.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out
def limit_grid(grid_0, grid_1):
    """Clamp a 2D launch grid to the device limits.

    Axis 0 is additionally capped at a quarter of the hardware maximum —
    presumably to bound concurrent tasks; the kernels' grid-stride loops
    cover any remaining work regardless of the clamp.
    """
    return (
        min(grid_0, MAX_GRID_SIZE_X // 4),
        min(grid_1, MAX_GRID_SIZE_Y),
    )
@libentry()
@triton.autotune(
    configs=[
        # (BLOCK_H, BLOCK_W) output-tile candidates: small tiles use more
        # pipeline stages / fewer warps, large tiles the opposite.
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides (in elements; output/indices are assumed contiguous)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0 (= N * C planes)
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Grid-stride loops over both axes: axis 0 walks (n, c) planes, axis 1
    # walks BLOCK_H x BLOCK_W output tiles, so any clamped launch grid still
    # covers all the work.
    task_num_1 = tl.cdiv(out_h, BLOCK_H) * tl.cdiv(out_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            # Decode the flat tile id into (tile row, tile col) of the output.
            num_w_blocks = tl.cdiv(out_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c

            h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

            dtype = input_ptr.type.element_ty
            min_val = get_dtype_min(dtype)
            # Running maximum and its flat (h * in_w + w) index; -1 marks
            # "no in-bounds element strictly above min_val seen yet".
            max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
            max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

            input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

            # Fully unrolled walk over the pooling-window taps.
            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    h_in = (
                        h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                    )
                    w_in = (
                        w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    # Taps landing in the (virtual, zero-width) padding region
                    # read min_val and never contribute an index.
                    in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                    input_offset = h_in * in_stride_h + w_in * in_stride_w
                    current_val = tl.load(
                        input_base_ptr + input_offset, mask=in_mask, other=min_val
                    )
                    # Flat index inside the (h, w) plane — PyTorch's
                    # return_indices convention.
                    current_idx = h_in * in_w + w_in

                    is_new_max = current_val > max_val_acc
                    max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
                    max_idx_acc = tl.where(
                        is_new_max & in_mask, current_idx, max_idx_acc
                    )

            # Stores assume output/indices are contiguous NCHW tensors.
            out_base_ptr = output_ptr + pid_nc * out_h * out_w
            indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
            out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            output_block_ptr = (
                out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
            )
            indices_block_ptr = (
                indices_base_ptr
                + out_h_offsets[:, None] * out_w
                + out_w_offsets[None, :]
            )

            # Guard the partial tiles on the right/bottom edges.
            out_mask = (out_h_offsets[:, None] < out_h) & (
                out_w_offsets[None, :] < out_w
            )
            tl.store(output_block_ptr, max_val_acc, mask=out_mask)
            tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
            pid_hw += grid_1
        pid_nc += grid_0
@libentry()
@triton.autotune(
    configs=[
        # (BLOCK_IN_H, BLOCK_IN_W) tiles over the *input* (grad_input) plane.
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices (contiguous NCHW in elements)
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Total number of tasks on axis 0 (= N * C planes)
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    # Gather formulation of the backward pass: each program owns a tile of
    # grad_input positions and sums the grad_output of every window that
    # (a) could contain that position and (b) actually selected it as its
    # argmax — so no atomics are needed.
    task_num_1 = tl.cdiv(in_h, BLOCK_IN_H) * tl.cdiv(in_w, BLOCK_IN_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    nc_idx = tl.program_id(0)
    while nc_idx < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks

            h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
            w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

            # Flat (h * in_w + w) index of each input position — the value the
            # forward kernel stored in `indices` for its argmax.
            current_input_flat_idx = (
                h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            )
            # Accumulate in fp32 independent of the grad dtype.
            grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

            indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
            grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    # Invert the forward mapping: an output (h_out, w_out)
                    # touches this input position via tap (kh, kw) iff
                    # h_in = h_out*stride - padding + kh*dilation, i.e. the
                    # numerator below divides evenly by the stride.
                    numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
                    numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

                    valid_map_mask = (numerator_h % stride_h == 0) & (
                        numerator_w % stride_w == 0
                    )
                    h_out = numerator_h // stride_h
                    w_out = numerator_w // stride_w
                    out_bounds_mask = (
                        (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
                    )
                    load_mask = valid_map_mask & out_bounds_mask

                    # Clamp invalid lanes to offset 0 so address arithmetic
                    # stays in-bounds; load_mask suppresses the actual loads.
                    safe_h_out = tl.where(load_mask, h_out, 0)
                    safe_w_out = tl.where(load_mask, w_out, 0)
                    out_offsets = safe_h_out * out_stride_h + safe_w_out

                    # other=-1 can never equal a valid flat index (>= 0), so
                    # masked-out lanes never match below.
                    indices_block = tl.load(
                        indices_base_ptr + out_offsets, mask=load_mask, other=-1
                    )
                    match_mask = indices_block == current_input_flat_idx

                    grad_block = tl.load(
                        grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
                    )
                    grad_acc += grad_block

            # grad_input is contiguous NCHW; guard partial edge tiles.
            grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
            grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
            tl.store(
                grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask
            )
            pid_hw += grid_1
        nc_idx += grid_0
259def _parse_pool_params(kernel_size, stride, padding, dilation):
260 def _parse_param(param, name, default=None):
261 if param is None:
262 return default
263 if isinstance(param, int):
264 return param, param
265 if isinstance(param, (list, tuple)) and len(param) == 2:
266 return param
267 raise ValueError(f"Invalid {name}: {param}")
269 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size")
270 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w))
271 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0))
272 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1))
274 if stride_h <= 0 or stride_w <= 0:
275 raise ValueError(
276 f"stride must be positive, but got stride=({stride_h}, {stride_w})"
277 )
278 if padding_h < 0 or padding_w < 0:
279 raise ValueError(
280 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})"
281 )
282 if dilation_h <= 0 or dilation_w <= 0:
283 raise ValueError(
284 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})"
285 )
287 return (
288 kernel_h,
289 kernel_w,
290 stride_h,
291 stride_w,
292 padding_h,
293 padding_w,
294 dilation_h,
295 dilation_w,
296 )
def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Max-pool an NCHW tensor, returning (pooled values, flat argmax indices)."""
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES FORWARD")
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    batch, channels, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    out_shape = (batch, channels, out_h, out_w)
    output = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(out_shape, device=input.device, dtype=torch.int64)

    # Nothing to launch for empty outputs.
    if output.numel() == 0:
        return output, indices

    nc_tasks = batch * channels

    def grid(meta):
        # One axis-1 task per output tile; limit_grid clamps to HW bounds and
        # the kernel's grid-stride loops pick up the remainder.
        hw_blocks = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return limit_grid(nc_tasks, hw_blocks)

    max_pool2d_forward_kernel[grid](
        input,
        output,
        indices,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        channels,
        in_h,
        in_w,
        out_h,
        out_w,
        nc_tasks,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return output, indices
def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Gradient of max_pool2d: route grad_output back through the saved indices."""
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES BACKWARD")
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    batch, channels, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 regardless of input dtype; cast back on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)
    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    nc_tasks = batch * channels

    def grid(meta):
        # Axis 1 tiles the *input* plane (gather-style backward).
        hw_blocks = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return limit_grid(nc_tasks, hw_blocks)

    # grad_output/indices are contiguous, so their strides are derived
    # directly from the output shape.
    max_pool2d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_h,
        in_w,
        out_h,
        out_w,
        out_h * out_w,
        out_w,
        1,
        nc_tasks,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return grad_input.to(grad_output.dtype)