Coverage for src/flag_gems/experimental_ops/_adaptive_avg_pool3d.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import torch
2import triton
3import triton.language as tl
# One Triton program handles exactly one output element: it unravels its
# linear program id into (n, c, d_o, h_o, w_o), sums the matching adaptive
# input window in float32, divides by the window size, and stores one scalar.
@triton.jit
def adaptive_avg_pool3d_kernel(
    in_ptr,  # pointer to input, logical shape (N, C, D_in, H_in, W_in)
    out_ptr,  # pointer to output, logical shape (N, C, D_out, H_out, W_out)
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    stride_in_n,  # element strides of the input tensor
    stride_in_c,
    stride_in_d,
    stride_in_h,
    stride_in_w,
    stride_out_n,  # element strides of the output tensor
    stride_out_c,
    stride_out_d,
    stride_out_h,
    stride_out_w,
):
    pid = tl.program_id(axis=0)

    # Unravel pid -> (n, c, d_o, h_o, w_o).  Extents are promoted to int64
    # scalars first so all flat-index arithmetic below is done in 64 bits.
    W_out_i64 = tl.full((), W_out, tl.int64)
    H_out_i64 = tl.full((), H_out, tl.int64)
    D_out_i64 = tl.full((), D_out, tl.int64)
    C_i64 = tl.full((), C, tl.int64)

    idx = tl.cast(pid, tl.int64)
    w_o = idx % W_out_i64
    idx = idx // W_out_i64
    h_o = idx % H_out_i64
    idx = idx // H_out_i64
    d_o = idx % D_out_i64
    idx = idx // D_out_i64
    c = idx % C_i64
    n = idx // C_i64

    # Compute start/end indices for each dimension (integer arithmetic):
    # start = floor(o * In / Out) and end = ceil((o + 1) * In / Out), the
    # standard adaptive-pooling window bounds (ceil done via add-then-floor-div).
    D_in_i64 = tl.full((), D_in, tl.int64)
    H_in_i64 = tl.full((), H_in, tl.int64)
    W_in_i64 = tl.full((), W_in, tl.int64)

    d0 = (d_o * D_in_i64) // D_out_i64
    d1 = ((d_o + 1) * D_in_i64 + D_out_i64 - 1) // D_out_i64
    h0 = (h_o * H_in_i64) // H_out_i64
    h1 = ((h_o + 1) * H_in_i64 + H_out_i64 - 1) // H_out_i64
    w0 = (w_o * W_in_i64) // W_out_i64
    w1 = ((w_o + 1) * W_in_i64 + W_out_i64 - 1) // W_out_i64

    # Per-window extents; the divisor is this window's true element count
    # (windows may differ in size when In is not a multiple of Out).
    dd = d1 - d0
    hh = h1 - h0
    ww = w1 - w0
    denom = dd * hh * ww

    # Base offsets and strides (int64)
    stride_in_n_i64 = tl.full((), stride_in_n, tl.int64)
    stride_in_c_i64 = tl.full((), stride_in_c, tl.int64)
    stride_in_d_i64 = tl.full((), stride_in_d, tl.int64)
    stride_in_h_i64 = tl.full((), stride_in_h, tl.int64)
    stride_in_w_i64 = tl.full((), stride_in_w, tl.int64)

    stride_out_n_i64 = tl.full((), stride_out_n, tl.int64)
    stride_out_c_i64 = tl.full((), stride_out_c, tl.int64)
    stride_out_d_i64 = tl.full((), stride_out_d, tl.int64)
    stride_out_h_i64 = tl.full((), stride_out_h, tl.int64)
    stride_out_w_i64 = tl.full((), stride_out_w, tl.int64)

    base_nc = n * stride_in_n_i64 + c * stride_in_c_i64

    # Accumulate the window sum in float32 regardless of the input dtype.
    acc = tl.zeros((), dtype=tl.float32)

    # Scalar triple loop over the (dd, hh, ww) window, one element per load.
    di = d0
    while di < d1:
        hi = h0
        while hi < h1:
            wi = w0
            while wi < w1:
                in_idx = (
                    base_nc
                    + di * stride_in_d_i64
                    + hi * stride_in_h_i64
                    + wi * stride_in_w_i64
                )
                val = tl.load(in_ptr + in_idx)
                acc += tl.cast(val, tl.float32)
                wi += 1
            hi += 1
        di += 1

    # Mean = float32 sum / float32 window size.
    denom_f = tl.cast(denom, tl.float32)
    out_val = acc / denom_f

    out_idx = (
        n * stride_out_n_i64
        + c * stride_out_c_i64
        + d_o * stride_out_d_i64
        + h_o * stride_out_h_i64
        + w_o * stride_out_w_i64
    )
    tl.store(out_ptr + out_idx, out_val)
112def _normalize_output_size_3d(output_size):
113 if isinstance(output_size, torch.Size):
114 output_size = tuple(output_size)
115 if isinstance(output_size, (list, tuple)):
116 if len(output_size) != 3:
117 raise ValueError(
118 "output_size for _adaptive_avg_pool3d must have 3 elements (D_out, H_out, W_out)"
119 )
120 return tuple(int(x) for x in output_size)
121 raise TypeError("output_size must be a sequence of three integers")
124def _prepare_5d_input(t):
125 if t.dim() == 5:
126 return t, False
127 if t.dim() == 4:
128 return t.unsqueeze(0), True # add N=1
129 raise ValueError(
130 "input for _adaptive_avg_pool3d must be 4D (C,D,H,W) or 5D (N,C,D,H,W)"
131 )
def _launch_adaptive_avg_pool3d_kernel(x, out):
    """Launch the Triton kernel that fills *out* with the adaptive 3D
    average pool of *x*.

    Args:
        x: 5D CUDA input tensor of shape (N, C, D_in, H_in, W_in).
        out: 5D CUDA output tensor of shape (N, C, D_out, H_out, W_out);
            written in place.  Strides of both tensors are forwarded, so
            non-contiguous layouts are handled by the kernel.

    Raises:
        ValueError: if either tensor is not a CUDA tensor.
    """
    # Explicit check instead of `assert`: asserts are stripped under `-O`,
    # and a kernel launch on CPU tensors would then fail far less clearly.
    if not (x.is_cuda and out.is_cuda):
        raise ValueError("Tensors must be CUDA tensors")

    N, C, D_in, H_in, W_in = x.shape
    D_out, H_out, W_out = out.shape[-3], out.shape[-2], out.shape[-1]

    stride_in_n, stride_in_c, stride_in_d, stride_in_h, stride_in_w = x.stride()
    stride_out_n, stride_out_c, stride_out_d, stride_out_h, stride_out_w = out.stride()

    # One program per output element; nothing to launch for empty tensors.
    total = N * C * D_out * H_out * W_out
    if total == 0:
        return

    grid = (total,)
    adaptive_avg_pool3d_kernel[grid](
        x,
        out,
        N,
        C,
        D_in,
        H_in,
        W_in,
        D_out,
        H_out,
        W_out,
        stride_in_n,
        stride_in_c,
        stride_in_d,
        stride_in_h,
        stride_in_w,
        stride_out_n,
        stride_out_c,
        stride_out_d,
        stride_out_h,
        stride_out_w,
        num_warps=4,
    )
def _adaptive_avg_pool3d(input: torch.Tensor, output_size):
    """Adaptive 3D average pooling over a 4D (C,D,H,W) or 5D (N,C,D,H,W)
    CUDA tensor.

    Allocates a fresh output with spatial dims *output_size* and the same
    device/dtype/layout as the input, fills it with the Triton kernel, and
    removes the synthetic batch dimension again for 4D inputs.
    """
    x5d, was_4d = _prepare_5d_input(input)
    D_out, H_out, W_out = _normalize_output_size_3d(output_size)

    N, C = x5d.shape[0], x5d.shape[1]
    out5d = torch.empty(
        (N, C, D_out, H_out, W_out),
        device=x5d.device,
        dtype=x5d.dtype,
        layout=x5d.layout,
    )

    _launch_adaptive_avg_pool3d_kernel(x5d, out5d)

    return out5d.squeeze(0) if was_4d else out5d
def _adaptive_avg_pool3d_out(input: torch.Tensor, output_size, out: torch.Tensor):
    """Out-variant of ``_adaptive_avg_pool3d``: pool *input* into *out*.

    *out* must match the input's batching: for a 4D input it may be 4D
    (C,D,H,W) or 5D with a leading dimension of 1; for a 5D input it must
    be 5D.  Its spatial dims must equal *output_size* and its device and
    dtype must match the input's.  Returns *out* (not the 5D view).
    """
    x5d, was_4d = _prepare_5d_input(input)
    D_out, H_out, W_out = _normalize_output_size_3d(output_size)

    # Bring `out` to the kernel's 5D (N, C, D, H, W) view without copying.
    if not was_4d:
        out5d = out
        if out5d.dim() != 5:
            raise ValueError("Provided 'out' must be 5D (N,C,D,H,W) when input is 5D")
    elif out.dim() == 4:
        out5d = out.unsqueeze(0)
    elif out.dim() == 5 and out.size(0) == 1:
        out5d = out
    else:
        raise ValueError("Provided 'out' must be 4D (C,D,H,W) when input is 4D")

    # Shape must match exactly — the kernel writes every output element.
    expected_shape = (x5d.size(0), x5d.size(1), D_out, H_out, W_out)
    if tuple(out5d.shape) != expected_shape:
        raise ValueError(
            f"out has incorrect shape. Expected {expected_shape}, got {tuple(out5d.shape)}"
        )

    if out5d.device != x5d.device or out5d.dtype != x5d.dtype:
        raise ValueError(
            "out must be on the same device and have the same dtype as input"
        )

    _launch_adaptive_avg_pool3d_kernel(x5d, out5d)

    return out