Coverage for src/flag_gems/experimental_ops/upsample_nearest3d.py: 0%
121 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import torch
2import triton
3import triton.language as tl
# Triton kernel: nearest-neighbor 3D upsampling over a 5D [N, C, D, H, W] tensor.
# Each program instance handles BLOCK_SIZE flattened output elements: it unravels
# the flat index into (n, c, od, oh, ow), maps each output coordinate to its
# nearest source coordinate via truncation, and copies one value per element.
@triton.jit
def upsample_nearest3d_kernel(
    in_ptr,  # pointer to input tensor data
    out_ptr,  # pointer to output tensor data
    N,  # batch size
    C,  # channels
    ID,  # input depth
    IH,  # input height
    IW,  # input width
    OD,  # output depth
    OH,  # output height
    OW,  # output width
    in_stride_n,  # input strides, in elements (passed explicitly so
    in_stride_c,  # non-contiguous inputs are addressed correctly)
    in_stride_d,
    in_stride_h,
    in_stride_w,
    out_stride_n,  # output strides, in elements
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    scale_d,  # multiplier used as src = floor(out_idx * scale); the host side
    scale_h,  # precomputes either in/out (output_size given) or 1/scale_factor
    scale_w,
    total_elements,  # N*C*OD*OH*OW; bounds the tail block via `mask`
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block: lanes past the end load 0 and store nothing.
    mask = offsets < total_elements
    # Unravel offsets into (n, c, od, oh, ow) for an output tensor of shape [N, C, OD, OH, OW]
    ow = offsets % OW
    tmp = offsets // OW
    oh = tmp % OH
    tmp = tmp // OH
    od = tmp % OD
    tmp = tmp // OD
    c = tmp % C
    n = tmp // C
    # Compute nearest input indices
    od_f = od.to(tl.float32)
    oh_f = oh.to(tl.float32)
    ow_f = ow.to(tl.float32)
    # .to(tl.int32) truncates toward zero == floor for these non-negative
    # products; tl.minimum clamps rounding overshoot to the last valid index.
    id_src = tl.minimum((od_f * scale_d).to(tl.int32), ID - 1)
    ih_src = tl.minimum((oh_f * scale_h).to(tl.int32), IH - 1)
    iw_src = tl.minimum((ow_f * scale_w).to(tl.int32), IW - 1)
    # Compute input/output offsets using strides
    in_offset = (
        n * in_stride_n
        + c * in_stride_c
        + id_src * in_stride_d
        + ih_src * in_stride_h
        + iw_src * in_stride_w
    )
    out_offset = (
        n * out_stride_n
        + c * out_stride_c
        + od * out_stride_d
        + oh * out_stride_h
        + ow * out_stride_w
    )
    vals = tl.load(in_ptr + in_offset, mask=mask, other=0)
    tl.store(out_ptr + out_offset, vals, mask=mask)
78def _ensure_5d_input(x: torch.Tensor):
79 if x.dim() != 5:
80 raise ValueError(
81 f"Expected 5D input [N, C, D, H, W], but got shape {tuple(x.shape)}"
82 )
83 return x
86def _normalize_output_size(output_size):
87 if output_size is None:
88 return None
89 if isinstance(output_size, torch.Size):
90 output_size = tuple(int(s) for s in output_size)
91 elif isinstance(output_size, (list, tuple)):
92 output_size = tuple(int(s) for s in output_size)
93 else:
94 raise ValueError("output_size must be a sequence of 3 integers or torch.Size")
95 if len(output_size) != 3:
96 raise ValueError("output_size must have length 3: (out_d, out_h, out_w)")
97 return output_size
100def _normalize_scale_factors(scales):
101 if scales is None:
102 return None
103 if isinstance(scales, (list, tuple)):
104 if len(scales) != 3:
105 raise ValueError(
106 "scale_factors must have length 3: (scale_d, scale_h, scale_w)"
107 )
108 return tuple(float(s) if s is not None else None for s in scales)
109 else:
110 raise ValueError("scale_factors must be a sequence of 3 floats")
113def _compute_out_size_and_kernel_scales(ID, IH, IW, output_size, scales_tuple):
114 # Returns (OD, OH, OW, kscale_d, kscale_h, kscale_w)
115 # kscale_* is the multiplier used as: src_idx = floor(out_idx * kscale_*)
116 if output_size is not None:
117 OD, OH, OW = int(output_size[0]), int(output_size[1]), int(output_size[2])
118 if OD <= 0 or OH <= 0 or OW <= 0:
119 raise ValueError("Output sizes must be positive")
120 # When output_size is given, kscale = input_size / output_size
121 kscale_d = float(ID) / float(OD)
122 kscale_h = float(IH) / float(OH)
123 kscale_w = float(IW) / float(OW)
124 else:
125 sd, sh, sw = scales_tuple
126 if sd is None or sh is None or sw is None:
127 raise ValueError(
128 "All scale factors (scale_d, scale_h, scale_w) must be provided when output_size is None"
129 )
130 if sd <= 0.0 or sh <= 0.0 or sw <= 0.0:
131 raise ValueError("Scale factors must be positive")
132 OD = int(torch.floor(torch.tensor(ID * sd)).item())
133 OH = int(torch.floor(torch.tensor(IH * sh)).item())
134 OW = int(torch.floor(torch.tensor(IW * sw)).item())
135 if OD <= 0 or OH <= 0 or OW <= 0:
136 raise ValueError("Computed output sizes must be positive")
137 # When scale_factors are given, src_idx = floor(out_idx / scale) = floor(out_idx * (1/scale))
138 kscale_d = 1.0 / float(sd)
139 kscale_h = 1.0 / float(sh)
140 kscale_w = 1.0 / float(sw)
141 return OD, OH, OW, kscale_d, kscale_h, kscale_w
def _launch_upsample_nearest3d(
    input: torch.Tensor,
    output: torch.Tensor,
    kscale_d: float,
    kscale_h: float,
    kscale_w: float,
):
    """Fill *output* from *input* via the Triton nearest-3D kernel.

    Element strides for both tensors are forwarded explicitly, so
    non-contiguous layouts are addressed correctly. Returns *output*;
    empty tensors skip the launch entirely.
    """
    N, C, ID, IH, IW = input.shape
    OD, OH, OW = output.shape[2], output.shape[3], output.shape[4]

    numel = N * C * OD * OH * OW
    if numel == 0:
        # Nothing to copy; launching a zero-work grid would be pointless.
        return output

    isn, isc, isd, ish, isw = input.stride()
    osn, osc, osd, osh, osw = output.stride()

    BLOCK_SIZE = 1024
    # One program per BLOCK_SIZE-sized slice of the flattened output.
    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    upsample_nearest3d_kernel[grid](
        input,
        output,
        N,
        C,
        ID,
        IH,
        IW,
        OD,
        OH,
        OW,
        isn,
        isc,
        isd,
        ish,
        isw,
        osn,
        osc,
        osd,
        osh,
        osw,
        float(kscale_d),
        float(kscale_h),
        float(kscale_w),
        numel,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
def upsample_nearest3d(
    input: torch.Tensor, output_size=None, scales_d=None, scales_h=None, scales_w=None
):
    """Nearest-neighbor 3D upsampling with per-axis scale arguments.

    Provide either *output_size* = (out_d, out_h, out_w) or all three of
    *scales_d*/*scales_h*/*scales_w*. Returns a freshly allocated tensor of
    shape [N, C, OD, OH, OW] matching the input's dtype/device/layout.
    """
    x = _ensure_5d_input(input)
    size = _normalize_output_size(output_size)
    per_axis_scales = None
    if size is None:
        # Individual scale args are only consulted when no explicit size is given.
        per_axis_scales = tuple(
            None if s is None else float(s) for s in (scales_d, scales_h, scales_w)
        )
    N, C, ID, IH, IW = x.shape
    OD, OH, OW, ksd, ksh, ksw = _compute_out_size_and_kernel_scales(
        ID, IH, IW, size, per_axis_scales
    )
    result = torch.empty(
        (N, C, OD, OH, OW), dtype=x.dtype, device=x.device, layout=x.layout
    )
    return _launch_upsample_nearest3d(x, result, ksd, ksh, ksw)
def upsample_nearest3d_vec(input: torch.Tensor, output_size=None, scale_factors=None):
    """Nearest-neighbor 3D upsampling taking scale factors as one sequence.

    Provide either *output_size* = (out_d, out_h, out_w) or *scale_factors* =
    (scale_d, scale_h, scale_w). Returns a freshly allocated tensor of shape
    [N, C, OD, OH, OW] matching the input's dtype/device/layout.
    """
    x = _ensure_5d_input(input)
    size = _normalize_output_size(output_size)
    # Scale factors are only consulted when no explicit size is given.
    factors = _normalize_scale_factors(scale_factors) if size is None else None
    N, C, ID, IH, IW = x.shape
    OD, OH, OW, ksd, ksh, ksw = _compute_out_size_and_kernel_scales(
        ID, IH, IW, size, factors
    )
    result = torch.empty(
        (N, C, OD, OH, OW), dtype=x.dtype, device=x.device, layout=x.layout
    )
    return _launch_upsample_nearest3d(x, result, ksd, ksh, ksw)
def upsample_nearest3d_out(
    input: torch.Tensor,
    output_size=None,
    scales_d=None,
    scales_h=None,
    scales_w=None,
    out: torch.Tensor = None,
):
    """Nearest-neighbor 3D upsampling written into a caller-provided tensor.

    Provide either *output_size* = (out_d, out_h, out_w) or all three of
    *scales_d*/*scales_h*/*scales_w*. *out* must be supplied, match the
    input's device and dtype, and have shape [N, C, OD, OH, OW]; it is
    filled in place and returned.

    Raises:
        ValueError: if *out* is missing or incompatible, if the input is not
            5D, or if sizes/scales are missing or non-positive.
    """
    # Fail fast: validate the mandatory 'out' argument before doing any
    # normalization or size computation (the original checked it only after).
    if out is None:
        raise ValueError("Argument 'out' must be provided for upsample_nearest3d_out")
    x = _ensure_5d_input(input)
    if out.device != x.device or out.dtype != x.dtype:
        raise ValueError(
            "Output tensor 'out' must have the same device and dtype as input"
        )

    output_size = _normalize_output_size(output_size)
    scales_tuple = None
    if output_size is None:
        # Individual scale args are only consulted when no explicit size is given.
        scales_tuple = (
            None if scales_d is None else float(scales_d),
            None if scales_h is None else float(scales_h),
            None if scales_w is None else float(scales_w),
        )
    N, C, ID, IH, IW = x.shape
    OD, OH, OW, ksd, ksh, ksw = _compute_out_size_and_kernel_scales(
        ID, IH, IW, output_size, scales_tuple
    )

    # The shape check needs the computed OD/OH/OW, so it comes last.
    expected_shape = (N, C, OD, OH, OW)
    if tuple(out.shape) != expected_shape:
        raise ValueError(
            f"Output tensor 'out' must have shape {expected_shape}, but got {tuple(out.shape)}"
        )

    _launch_upsample_nearest3d(x, out, ksd, ksh, ksw)
    return out