Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/avg_pool2d.py: 0%
150 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
9logger = logging.getLogger(__name__)
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return one spatial dimension of a 2D pooling output.

    Implements the standard PyTorch pooling size formula: the dilated
    kernel spans ``(kernel_size - 1) * dilation + 1`` elements, and the
    window slides over the input padded by ``padding`` on both sides.
    With ``ceil_mode`` the division rounds up, but the last window is
    dropped again if it would start entirely inside the right padding.
    """
    span = (kernel_size - 1) * dilation + 1
    free = in_size + 2 * padding - span
    if not ceil_mode:
        return free // stride + 1
    # Ceil division via negated floor division (stride > 0).
    out = -(-free // stride) + 1
    # PyTorch rule: a window whose start lies at or past in_size + padding
    # would read padding only, so it is not emitted.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out
# Forward average-pooling kernel.
# Grid: axis 0 enumerates (n, c) planes; axis 1 enumerates (BLOCK_H, BLOCK_W)
# tiles of the output plane. Each program computes one such tile.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "not set" (None is mapped to 0.0 by the wrapper)
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Decompose the flat tile id into (tile row, tile column) of the output.
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    # Decompose the flat plane id into (batch, channel).
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    # Number of in-bounds taps per output position; the divisor when
    # COUNT_INCLUDE_PAD is False.
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    # Unrolled loop over the kernel window; out-of-bounds taps (padding)
    # are masked out of both the sum and the count.
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
            input_offset = h_in * in_stride_h + w_in * in_stride_w
            current_val = tl.load(
                input_base_ptr + input_offset, mask=in_mask, other=0.0
            )
            sum_acc += tl.where(in_mask, current_val, 0.0)
            count_acc += in_mask.to(tl.int32)
    count_divisor = count_acc.to(tl.float32)
    if COUNT_INCLUDE_PAD:
        # Predicate is always true; tl.where only broadcasts the scalar
        # kernel area to the (BLOCK_H, BLOCK_W) tile shape.
        default_divisor = tl.where(
            count_divisor >= 0, float(kernel_h * kernel_w), count_divisor
        )
    else:
        default_divisor = count_divisor
    # "+ default_divisor * 0" broadcasts the scalar override to tile shape;
    # divisor_override == 0 acts as the "unset" sentinel.
    divisor = tl.where(
        divisor_override != 0, divisor_override + default_divisor * 0, default_divisor
    )
    # Guard against divide-by-zero on fully-masked lanes.
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)
    # NOTE(review): the store assumes a contiguous NCHW output (plane offset
    # pid_nc * out_h * out_w, no output strides) — the wrapper allocates it
    # with torch.empty, which satisfies this.
    out_base_ptr = output_ptr + pid_nc * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )
# Backward average-pooling kernel, gather formulation: each program computes
# one (BLOCK_H, BLOCK_W) tile of grad_input directly (no atomics). For every
# kernel tap it derives which output position would have read the input pixel
# through that tap, and accumulates that output's gradient / divisor.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,  # 0 means "not set" (None is mapped to 0.0 by the wrapper)
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)
    # Tiles cover the INPUT plane here (grad_input is what we produce).
    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    for kh_loop in tl.static_range(0, kernel_h):
        for kw_loop in tl.static_range(0, kernel_w):
            # Invert h_in = h_out * stride - padding + k * dilation:
            # the candidate output index is (h_in + padding - k*dilation) / stride,
            # valid only when non-negative and exactly divisible by the stride.
            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w
            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)
            h_out = h_out_num // stride_h
            w_out = w_out_num // stride_w
            h_out_mask = h_valid_map & (h_out < out_h)
            w_out_mask = w_valid_map & (w_out < out_w)
            out_mask = h_out_mask & w_out_mask
            # Compute count for this output position (for count_include_pad=False)
            # by re-scanning that output's window and counting in-bounds taps.
            h_start = h_out * stride_h - padding_h
            w_start = w_out * stride_w - padding_w
            count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
            for kh_count in tl.static_range(0, kernel_h):
                for kw_count in tl.static_range(0, kernel_w):
                    h_in_for_count = h_start + kh_count * dilation_h
                    w_in_for_count = w_start + kw_count * dilation_w
                    is_valid = (
                        (h_in_for_count >= 0)
                        & (h_in_for_count < in_h)
                        & (w_in_for_count >= 0)
                        & (w_in_for_count < in_w)
                    )
                    count += is_valid.to(tl.int32)
            count_divisor = count.to(tl.float32)
            if COUNT_INCLUDE_PAD:
                # Always-true predicate: broadcasts the scalar kernel area
                # to the tile shape (mirrors the forward kernel).
                default_divisor = tl.where(
                    count_divisor >= 0, float(kernel_h * kernel_w), count_divisor
                )
            else:
                default_divisor = count_divisor
            # Scalar override broadcast via "+ default_divisor * 0";
            # divisor_override == 0 means "unset".
            divisor = tl.where(
                divisor_override != 0,
                divisor_override + default_divisor * 0,
                default_divisor,
            )
            # Avoid div-by-zero on lanes that out_mask will discard anyway.
            divisor = tl.where(divisor == 0, 1.0, divisor)
            grad_out_ptr = (
                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
            )
            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
            grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)
    # Each input pixel's full gradient was accumulated above, so a plain
    # (overwriting) store is sufficient.
    grad_input_store_ptr = (
        grad_input_block_ptr
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )
254def _parse_pool_params(kernel_size, stride, padding):
255 if isinstance(kernel_size, int):
256 kernel_h = kernel_w = kernel_size
257 else:
258 kernel_h, kernel_w = kernel_size
260 if stride is None or (isinstance(stride, (list, tuple)) and not stride):
261 stride_h, stride_w = kernel_h, kernel_w
262 elif isinstance(stride, int):
263 stride_h = stride_w = stride
264 else:
265 stride_h, stride_w = stride
267 if isinstance(padding, int):
268 padding_h = padding_w = padding
269 else:
270 padding_h, padding_w = padding
272 if stride_h <= 0 or stride_w <= 0:
273 raise ValueError("stride must be greater than zero")
275 if padding_h < 0 or padding_w < 0:
276 raise ValueError("padding must be non-negative")
278 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2:
279 raise ValueError("pad should be smaller than or equal to half of kernel size")
281 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w
def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """2D average pooling over an NCHW tensor using a Triton kernel.

    Mirrors ``torch.nn.functional.avg_pool2d``: *kernel_size*, *stride*
    and *padding* accept ints or pairs; *stride* defaults to the kernel
    size. Returns a new tensor of shape (N, C, out_h, out_w) with the
    input's dtype and device.
    """
    logger.debug("GEMS AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    # The kernel indexes the input via explicit strides but writes the
    # output assuming contiguous NCHW, so normalize the input layout.
    input = input.contiguous()

    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
        kernel_size, stride, padding
    )
    dilation_h = dilation_w = 1  # avg_pool2d has no dilation in PyTorch

    in_n, in_c, in_h, in_w = input.shape
    out_h = pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    # Nothing to compute for empty batches/planes.
    if output.numel() == 0:
        return output

    in_stride_n, in_stride_c, in_stride_h, in_stride_w = input.stride()

    # One program per (n, c) plane on axis 0, one per output tile on axis 1.
    def grid(meta):
        tiles_h = triton.cdiv(out_h, meta["BLOCK_H"])
        tiles_w = triton.cdiv(out_w, meta["BLOCK_W"])
        return (in_n * in_c, tiles_h * tiles_w)

    avg_pool2d_forward_kernel[grid](
        input,
        output,
        in_stride_n,
        in_stride_c,
        in_stride_h,
        in_stride_w,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # 0.0 is the in-kernel sentinel for "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return output
def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of :func:`avg_pool2d` with respect to *input*.

    Accumulates in float32 for precision and casts the result back to
    ``grad_output``'s dtype. Returns a tensor shaped like *input*.
    """
    logger.debug("GEMS AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    grad_output = grad_output.contiguous()

    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
        kernel_size, stride, padding
    )
    dilation_h = dilation_w = 1  # avg_pool2d has no dilation in PyTorch

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # float32 accumulator regardless of input dtype.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_output.numel() == 0:
        return grad_input.to(grad_output.dtype)

    # One program per (n, c) plane on axis 0, one per INPUT tile on axis 1
    # (the backward kernel is a gather over grad_input).
    def grid(meta):
        tiles_h = triton.cdiv(in_h, meta["BLOCK_H"])
        tiles_w = triton.cdiv(in_w, meta["BLOCK_W"])
        return (in_n * in_c, tiles_h * tiles_w)

    avg_pool2d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_output.stride(0),
        grad_output.stride(1),
        grad_output.stride(2),
        grad_output.stride(3),
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # 0.0 is the in-kernel sentinel for "no override".
        divisor_override=0.0 if divisor_override is None else divisor_override,
    )

    return grad_input.to(grad_output.dtype)