Coverage for src/flag_gems/ops/avg_pool2d.py: 43% (150 statements)
Generated by coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
9logger = logging.getLogger(__name__)
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return one spatial dimension of a 2D pooling output.

    Implements the standard pooling shape formula: the window spans
    ``dilation * (kernel_size - 1) + 1`` input elements, and ``ceil_mode``
    rounds the division up instead of down. In ceil mode the final window
    is dropped when it would start entirely inside the right padding
    (i.e. at or past ``in_size + padding``).
    """
    window_span = dilation * (kernel_size - 1) + 1
    reach = in_size + 2 * padding - window_span

    if not ceil_mode:
        return reach // stride + 1

    # Ceil division via floor division (stride > 0).
    out = (reach + stride - 1) // stride + 1
    # A window that begins beyond the padded input contributes nothing.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    # avg_pool2d forward: each program computes one BLOCK_H x BLOCK_W tile of
    # the output plane for a single (n, c) slice.
    # Grid: axis 0 = n * c (flattened batch/channel), axis 1 = output H/W tiles.
    input_ptr,
    output_ptr,
    # Input tensor strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "no override" (host passes 0.0 for None)
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Decompose the flat tile id into (tile row, tile col) over the output plane.
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    # Decompose the flat (n, c) program id.
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    # Accumulate sums in fp32 regardless of input dtype; count tracks how many
    # in-bounds taps each output element saw (used when COUNT_INCLUDE_PAD is off).
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    # Walk the pooling window; each (kh, kw) tap gathers one input element per
    # output position. Out-of-bounds (padding) taps are masked to 0.
    for kh in range(0, kernel_h):
        for kw in range(0, kernel_w):
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=0.0
            )
            sum_acc += tl.where(in_mask, current_val, 0.0)
            count_acc += in_mask.to(tl.int32)
    # Divisor precedence: explicit override > full window size (include pad)
    # > per-element count of in-bounds taps.
    if divisor_override != 0:
        divisor = tl.full((BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32)
    elif COUNT_INCLUDE_PAD:
        divisor = tl.full((BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32)
    else:
        divisor = count_acc.to(tl.float32)
    # Guard against divisor == 0 (possible for fully-padded windows).
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)
    # NOTE: the store assumes a contiguous NCHW output laid out as
    # (n*c, out_h, out_w) — consistent with the host allocating via torch.empty.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    # avg_pool2d backward: each program computes one BLOCK_H x BLOCK_W tile of
    # grad_input for a single (n, c) slice by gathering from grad_output
    # (scatter-free: each input element sums the windows that contain it).
    # Grid: axis 0 = n * c, axis 1 = input H/W tiles.
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides (in elements)
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "no override" (host passes 0.0 for None)
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Tiles cover the INPUT plane here (unlike the forward kernel).
    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    # Invert the forward mapping: input position (h, w) was tap (kh, kw) of
    # output window (h_out, w_out) iff
    #   h + padding_h - kh * dilation_h == h_out * stride_h  (and >= 0),
    # i.e. the numerator must be non-negative and divisible by the stride.
    for kh_loop in range(kernel_h):
        for kw_loop in range(kernel_w):
            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w
            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)
            h_out = h_out_num // stride_h
            w_out = w_out_num // stride_w
            h_out_mask = h_valid_map & (h_out < out_h)
            w_out_mask = w_valid_map & (w_out < out_w)
            out_mask = h_out_mask & w_out_mask
            # Same divisor precedence as the forward kernel:
            # override > full window size > per-window in-bounds count.
            if divisor_override != 0:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                )
            elif COUNT_INCLUDE_PAD:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                )
            else:
                # Re-derive each contributing window's origin and recount its
                # in-bounds taps, matching the forward kernel's count_acc.
                h_start = h_out * stride_h - padding_h
                w_start = w_out * stride_w - padding_w
                count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
                for kh_count in range(0, kernel_h):
                    for kw_count in range(0, kernel_w):
                        h_in_for_count = h_start + kh_count * dilation_h
                        w_in_for_count = w_start + kw_count * dilation_w
                        is_valid = (
                            (h_in_for_count >= 0)
                            & (h_in_for_count < in_h)
                            & (w_in_for_count >= 0)
                            & (w_in_for_count < in_w)
                        )
                        count += is_valid.to(tl.int32)
                divisor = count.to(tl.float32)
            # Avoid 0/0 on lanes that out_mask discards anyway.
            divisor = tl.where(divisor == 0, 1.0, divisor)
            grad_out_ptr = (
                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
            )
            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
            grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)
    grad_input_store_ptr = (
        grad_input_block_ptr
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )
258def _parse_pool_params(kernel_size, stride, padding):
259 if isinstance(kernel_size, int):
260 kernel_h = kernel_w = kernel_size
261 else:
262 kernel_h, kernel_w = kernel_size
264 if stride is None or (isinstance(stride, (list, tuple)) and not stride):
265 stride_h, stride_w = kernel_h, kernel_w
266 elif isinstance(stride, int):
267 stride_h = stride_w = stride
268 else:
269 stride_h, stride_w = stride
271 if isinstance(padding, int):
272 padding_h = padding_w = padding
273 else:
274 padding_h, padding_w = padding
276 if stride_h <= 0 or stride_w <= 0:
277 raise ValueError("stride must be greater than zero")
279 if padding_h < 0 or padding_w < 0:
280 raise ValueError("padding must be non-negative")
282 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2:
283 raise ValueError("pad should be smaller than or equal to half of kernel size")
285 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w
def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """2D average pooling over an NCHW tensor using the Triton forward kernel.

    Mirrors ``torch.nn.functional.avg_pool2d`` semantics: ``ceil_mode``
    controls output rounding, ``count_include_pad`` chooses the divisor,
    and a non-None ``divisor_override`` replaces it entirely.
    """
    logger.debug("GEMS AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    # The forward kernel stores through a flat (n*c, out_h, out_w) layout,
    # so force a contiguous input as well.
    input = input.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding)
    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = params
    dilation_h, dilation_w = 1, 1  # avg_pool2d has no dilation; fixed at 1

    in_n, in_c, in_h, in_w = input.shape
    out_h = pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    # Nothing to launch for an empty output (e.g. zero batch).
    if output.numel() == 0:
        return output

    def grid(meta):
        tiles_h = triton.cdiv(out_h, meta["BLOCK_H"])
        tiles_w = triton.cdiv(out_w, meta["BLOCK_W"])
        return (in_n * in_c, tiles_h * tiles_w)

    avg_pool2d_forward_kernel[grid](
        input,
        output,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0 as "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return output
def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of :func:`avg_pool2d` with respect to ``input``.

    Launches the Triton backward kernel, which gathers from ``grad_output``
    into an fp32 ``grad_input`` buffer; the result is cast back to
    ``grad_output``'s dtype before returning.
    """
    logger.debug("GEMS AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    grad_output = grad_output.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding)
    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = params
    dilation_h, dilation_w = 1, 1  # avg_pool2d has no dilation; fixed at 1

    in_n, in_c, in_h, in_w = input.shape
    # Output spatial extent comes from grad_output rather than being recomputed.
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 for precision; cast to the caller's dtype on return.
    grad_input = torch.zeros_like(input, dtype=torch.float32)
    if grad_output.numel() == 0:
        return grad_input.to(grad_output.dtype)

    def grid(meta):
        tiles_h = triton.cdiv(in_h, meta["BLOCK_H"])
        tiles_w = triton.cdiv(in_w, meta["BLOCK_W"])
        return (in_n * in_c, tiles_h * tiles_w)

    avg_pool2d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_output.stride(0),
        grad_output.stride(1),
        grad_output.stride(2),
        grad_output.stride(3),
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0 as "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return grad_input.to(grad_output.dtype)