Coverage for src/flag_gems/runtime/backend/_cambricon/ops/avg_pool2d.py: 0%
177 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
9from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y
11logger = logging.getLogger(__name__)
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the length of one pooled spatial dimension.

    Follows PyTorch's pooling shape formula: the window's effective span is
    ``(kernel_size - 1) * dilation + 1``. With ``ceil_mode`` the division is
    rounded up, but the last window is dropped when it would start entirely
    inside the right padding.
    """
    span = (kernel_size - 1) * dilation + 1
    free = in_size + 2 * padding - span
    if not ceil_mode:
        return free // stride + 1
    # Ceiling division expressed as a negated floor division; exact for any
    # integer numerator and positive stride.
    out = -(-free // stride) + 1
    # PyTorch rule: the final window must begin before the padded right edge.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out
def limit_grid(grid_0, grid_1):
    """Clamp a 2-D launch grid to the backend's hardware limits.

    Axis 0 is additionally restricted to a quarter of ``MAX_GRID_SIZE_X``;
    kernels compensate by looping over the remaining tasks.
    """
    capped_0 = min(grid_0, MAX_GRID_SIZE_X // 4)
    capped_1 = min(grid_1, MAX_GRID_SIZE_Y)
    return capped_0, capped_1
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """AvgPool2d forward.

    Axis 0 of the launch grid walks (n, c) slices, axis 1 walks
    (BLOCK_H, BLOCK_W) tiles of the output plane. Because the host may clamp
    the grid below the task count (see ``limit_grid``), each program loops
    over its axis with a ``tl.num_programs`` stride until all tasks are done.

    ``divisor_override`` uses 0 as the "no override" sentinel (the host
    forbids an explicit 0). The input is addressed through its strides; the
    output is written assuming a contiguous NCHW layout
    (``pid_nc * out_h * out_w``).
    """
    # Number of output tiles per (n, c) slice for the chosen block shape.
    task_num_1 = tl.cdiv(out_h, BLOCK_H) * tl.cdiv(out_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            # Decompose the flat tile id into (row-block, col-block) indices.
            num_w_blocks = tl.cdiv(out_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c
            h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            # Accumulate the window sum in fp32 and the count of in-bounds
            # elements (needed when padding is excluded from the divisor).
            sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
            count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
            input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
            for kh in range(0, kernel_h):
                for kw in range(0, kernel_w):
                    # Input coordinates covered by this kernel tap for every
                    # output position in the tile.
                    h_in = (
                        h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                    )
                    w_in = (
                        w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                    input_offset = h_in * in_stride_h + w_in * in_stride_w
                    current_val = tl.load(
                        input_base_ptr + input_offset, mask=in_mask, other=0.0
                    )
                    sum_acc += tl.where(in_mask, current_val, 0.0)
                    count_acc += in_mask.to(tl.int32)
            # Divisor precedence: explicit override, then the full window
            # size (padding counted), then the number of valid elements.
            # NOTE(review): with ceil_mode the last window can extend past the
            # padded input; verify the fixed kernel_h * kernel_w divisor in
            # the COUNT_INCLUDE_PAD branch matches PyTorch for that edge.
            if divisor_override != 0:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                )
            elif COUNT_INCLUDE_PAD:
                divisor = tl.full(
                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                )
            else:
                divisor = count_acc.to(tl.float32)
            # Guard against divide-by-zero for fully-padded windows.
            output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)
            # Output is contiguous NCHW, so the slice base is a plain offset.
            out_base_ptr = output_ptr + pid_nc * out_h * out_w
            out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            output_block_ptr = (
                out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
            )
            out_mask = (out_h_offsets[:, None] < out_h) & (
                out_w_offsets[None, :] < out_w
            )
            tl.store(
                output_block_ptr,
                output_vals.to(output_ptr.type.element_ty),
                mask=out_mask,
            )
            pid_hw += grid_1
        pid_nc += grid_0
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def avg_pool2d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    task_num_0,
    # Strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """AvgPool2d backward (gradient w.r.t. the input), gather style.

    Instead of scattering from output positions (which would need atomics),
    each program owns a (BLOCK_H, BLOCK_W) tile of *input* positions and, for
    every kernel tap, solves for the output window that would have read that
    position. A window exists only when the offset is non-negative and
    divisible by the stride. Contributions are accumulated in fp32 and each
    input element is written exactly once, so no atomics are required.

    Launch-grid handling mirrors the forward kernel: both axes loop with a
    ``tl.num_programs`` stride because the host may clamp the grid
    (``limit_grid``). ``divisor_override == 0`` means "no override".
    """
    # Number of input tiles per (n, c) slice for the chosen block shape.
    task_num_1 = tl.cdiv(in_h, BLOCK_H) * tl.cdiv(in_w, BLOCK_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    pid_nc = tl.program_id(0)
    while pid_nc < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            num_w_blocks = tl.cdiv(in_w, BLOCK_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            n_idx = pid_nc // in_c
            c_idx = pid_nc % in_c
            grad_input_block_ptr = (
                grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
            )
            grad_output_base_ptr = (
                grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
            )
            h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
            # fp32 accumulator for the gathered gradient of this input tile.
            grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
            for kh_loop in range(kernel_h):
                for kw_loop in range(kernel_w):
                    # Candidate window origin (scaled by stride): an output
                    # window read this input position through tap (kh, kw)
                    # iff this numerator is >= 0 and divisible by the stride.
                    h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
                    w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w
                    h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
                    w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)
                    h_out = h_out_num // stride_h
                    w_out = w_out_num // stride_w
                    h_out_mask = h_valid_map & (h_out < out_h)
                    w_out_mask = w_valid_map & (w_out < out_w)
                    out_mask = h_out_mask & w_out_mask
                    # Recompute the forward divisor for that window: override,
                    # full window size, or the count of in-bounds elements.
                    if divisor_override != 0:
                        divisor = tl.full(
                            (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                        )
                    elif COUNT_INCLUDE_PAD:
                        divisor = tl.full(
                            (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
                        )
                    else:
                        # Re-count how many taps of the matched window land
                        # inside the (unpadded) input, as the forward did.
                        h_start = h_out * stride_h - padding_h
                        w_start = w_out * stride_w - padding_w
                        count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
                        for kh_count in range(0, kernel_h):
                            for kw_count in range(0, kernel_w):
                                h_in_for_count = h_start + kh_count * dilation_h
                                w_in_for_count = w_start + kw_count * dilation_w
                                is_valid = (
                                    (h_in_for_count >= 0)
                                    & (h_in_for_count < in_h)
                                    & (w_in_for_count >= 0)
                                    & (w_in_for_count < in_w)
                                )
                                count += is_valid.to(tl.int32)
                        divisor = count.to(tl.float32)
                    # Avoid 0/0 on positions masked out below anyway.
                    divisor = tl.where(divisor == 0, 1.0, divisor)
                    grad_out_ptr = (
                        grad_output_base_ptr
                        + h_out * out_stride_h
                        + w_out * out_stride_w
                    )
                    grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
                    grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)
            grad_input_store_ptr = (
                grad_input_block_ptr
                + h_in_offsets[:, None] * in_stride_h
                + w_in_offsets[None, :] * in_stride_w
            )
            in_write_mask = (h_in_offsets[:, None] < in_h) & (
                w_in_offsets[None, :] < in_w
            )
            tl.store(
                grad_input_store_ptr,
                grad_acc.to(grad_input_ptr.type.element_ty),
                mask=in_write_mask,
            )
            pid_hw += grid_1
        pid_nc += grid_0
301def _parse_pool_params(kernel_size, stride, padding):
302 if isinstance(kernel_size, int):
303 kernel_h = kernel_w = kernel_size
304 else:
305 kernel_h, kernel_w = kernel_size
307 if stride is None or (isinstance(stride, (list, tuple)) and not stride):
308 stride_h, stride_w = kernel_h, kernel_w
309 elif isinstance(stride, int):
310 stride_h = stride_w = stride
311 else:
312 stride_h, stride_w = stride
314 if isinstance(padding, int):
315 padding_h = padding_w = padding
316 else:
317 padding_h, padding_w = padding
319 if stride_h <= 0 or stride_w <= 0:
320 raise ValueError("stride must be greater than zero")
322 if padding_h < 0 or padding_w < 0:
323 raise ValueError("padding must be non-negative")
325 if padding_h > kernel_h // 2 or padding_w > kernel_w // 2:
326 raise ValueError("pad should be smaller than or equal to half of kernel size")
328 return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w
def avg_pool2d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Average pooling over the last two dimensions of ``input``.

    Mirrors ``torch.nn.functional.avg_pool2d`` for 4D (N, C, H, W) input and
    additionally accepts the unbatched 3D (C, H, W) layout that the reference
    API supports (the previous version crashed on 3D input).

    Args:
        input: 3D or 4D tensor; made contiguous before the kernel launch.
        kernel_size, stride, padding: int or pair of ints. ``stride``
            defaults to ``kernel_size`` when ``None`` or empty.
        ceil_mode: round the output size up instead of down.
        count_include_pad: include zero padding in the averaging divisor.
        divisor_override: fixed divisor used instead of the window count.

    Returns:
        The pooled tensor, same dtype/device and batch layout as ``input``.

    Raises:
        ValueError: if ``divisor_override == 0`` or pooling params are invalid.
    """
    logger.debug("GEMS_CAMBRICON AVG_POOL2D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    # Accept the unbatched (C, H, W) layout by adding a temporary batch dim.
    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)

    input = input.contiguous()

    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
        kernel_size, stride, padding
    )
    # avg_pool2d has no dilation in the ATen API; fixed at 1 for the kernel.
    dilation_h, dilation_w = 1, 1

    in_n, in_c, in_h, in_w = input.shape

    out_h = pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    # The kernel writes the output assuming a contiguous NCHW layout, which
    # torch.empty guarantees here.
    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )

    if output.numel() != 0:

        def grid(meta):
            grid_0 = in_n * in_c
            grid_1 = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
                out_w, meta["BLOCK_W"]
            )
            # Clamp to hardware limits; the kernel loops over leftover tasks.
            return limit_grid(grid_0, grid_1)

        task_num_0 = in_n * in_c

        avg_pool2d_forward_kernel[grid](
            input,
            output,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            task_num_0,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
            COUNT_INCLUDE_PAD=count_include_pad,
            # 0.0 is the in-kernel sentinel for "no override".
            divisor_override=divisor_override if divisor_override is not None else 0.0,
        )

    return output.squeeze(0) if unbatched else output
def avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Gradient of ``avg_pool2d`` with respect to ``input``.

    Accepts the 4D (N, C, H, W) layout and, like the reference ATen op, the
    unbatched 3D (C, H, W) layout (the previous version crashed on 3D input).
    The output spatial size is taken from ``grad_output``, so ``ceil_mode``
    is accepted only for ATen signature compatibility.

    Args:
        grad_output: upstream gradient; made contiguous before the launch.
        input: forward input tensor (supplies shape and strides).
        kernel_size, stride, padding: pooling geometry, as in the forward.
        ceil_mode: unused here (see above).
        count_include_pad: include zero padding in the averaging divisor.
        divisor_override: fixed divisor used instead of the window count.

    Returns:
        Gradient tensor shaped like ``input`` with ``grad_output``'s dtype.

    Raises:
        ValueError: if ``divisor_override == 0`` or pooling params are invalid.
    """
    logger.debug("GEMS_CAMBRICON AVG_POOL2D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    # Accept the unbatched (C, H, W) layout by adding a temporary batch dim.
    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)

    grad_output = grad_output.contiguous()

    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
        kernel_size, stride, padding
    )
    # avg_pool2d has no dilation in the ATen API; fixed at 1 for the kernel.
    dilation_h, dilation_w = 1, 1

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # Accumulate in fp32 regardless of the input dtype; cast back at the end.
    grad_input = torch.zeros_like(input, dtype=torch.float32)

    if grad_output.numel() != 0:

        def grid(meta):
            grid_0 = in_n * in_c
            grid_1 = triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(
                in_w, meta["BLOCK_W"]
            )
            # Clamp to hardware limits; the kernel loops over leftover tasks.
            return limit_grid(grid_0, grid_1)

        task_num_0 = in_n * in_c
        avg_pool2d_backward_kernel[grid](
            grad_output,
            grad_input,
            in_c,
            in_h,
            in_w,
            out_h,
            out_w,
            task_num_0,
            grad_input.stride(0),
            grad_input.stride(1),
            grad_input.stride(2),
            grad_input.stride(3),
            grad_output.stride(0),
            grad_output.stride(1),
            grad_output.stride(2),
            grad_output.stride(3),
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
            COUNT_INCLUDE_PAD=count_include_pad,
            # 0.0 is the in-kernel sentinel for "no override".
            divisor_override=divisor_override if divisor_override is not None else 0.0,
        )

    result = grad_input.to(grad_output.dtype)
    return result.squeeze(0) if unbatched else result