Coverage for src/flag_gems/experimental_ops/_upsample_nearest_exact1d.py: 0%
128 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _upsample_nearest_exact1d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    IW,
    OW,
    sN_in,
    sC_in,
    sW_in,
    sN_out,
    sC_out,
    sW_out,
    use_scales: tl.constexpr,
    scale_w,
    BLOCK_W: tl.constexpr,
):
    """Nearest-exact 1D upsampling of an (N, C, IW) tensor into (N, C, OW).

    Grid layout: axis 0 walks blocks of ``BLOCK_W`` output positions along the
    width, axis 1 is the flattened ``n * C + c`` row index.

    Args:
        in_ptr / out_ptr: base pointers of the input and output tensors.
        N, C: batch and channel counts (``N`` is not read by the kernel).
        IW, OW: input and output widths.
        sN_in, sC_in, sW_in: input strides, in elements.
        sN_out, sC_out, sW_out: output strides, in elements.
        use_scales: compile-time switch; when true the source index is derived
            from ``scale_w`` instead of the ``IW / OW`` ratio.
        scale_w: user-provided scale factor (only read when ``use_scales``).
        BLOCK_W: number of output positions handled per program.
    """
    pid_w = tl.program_id(0)
    pid_nc = tl.program_id(1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < OW

    # Recover (n, c) from the flattened row index.
    n = pid_nc // C
    c = pid_nc - n * C

    base_in = n * sN_in + c * sC_in
    base_out = n * sN_out + c * sC_out

    # Nearest-exact samples at the centre of each output pixel:
    #   iw = min(floor((ow + 0.5) * in_per_out), IW - 1)
    # The previous code omitted the +0.5 offset, which is the legacy "nearest"
    # rule and picks different source pixels for non-integer ratios.
    if use_scales:
        ow_f = offs_w.to(tl.float32)
        inv_scale = 1.0 / scale_w
        iw = tl.floor((ow_f + 0.5) * inv_scale).to(tl.int32)
    else:
        # floor((ow + 0.5) * IW / OW) == ((2 * ow + 1) * IW) // (2 * OW);
        # done in 64-bit so the doubled product cannot overflow int32.
        ow_i = offs_w.to(tl.int64)
        iw = ((2 * ow_i + 1) * IW) // (2 * OW)
    # Clamp so rounding (and masked-off lanes) never index past the input row.
    iw = tl.minimum(iw, IW - 1)

    in_ptrs = in_ptr + base_in + iw * sW_in
    x = tl.load(in_ptrs, mask=mask)

    out_ptrs = out_ptr + base_out + offs_w * sW_out
    tl.store(out_ptrs, x, mask=mask)
57def _parse_size_1d(val):
58 if val is None:
59 return None
60 if isinstance(val, torch.Size):
61 return int(val[-1]) if len(val) > 0 else None
62 if isinstance(val, (list, tuple)):
63 if len(val) == 0:
64 return None
65 return int(val[-1])
66 return int(val)
69def _parse_scale_1d(val):
70 if val is None:
71 return None
72 if isinstance(val, (list, tuple)):
73 if len(val) == 0:
74 return None
75 return float(val[-1])
76 return float(val)
79def _compute_out_w(iw, output_size, scale):
80 if output_size is not None:
81 return int(output_size)
82 if scale is None:
83 raise ValueError(
84 "Either output_size or scale must be provided for _upsample_nearest_exact1d."
85 )
86 # Follow common convention: OW = floor(IW * scale)
87 return int(math.floor(iw * scale))
90def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None):
91 if input.ndim != 3:
92 raise ValueError(
93 f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}"
94 )
95 if not input.is_cuda or not out.is_cuda:
96 # Fallback to the native operator on CPU or non-CUDA devices
97 return torch.ops.aten._upsample_nearest_exact1d(
98 input, [out.shape[-1]], [scale] if scale is not None else None
99 )
101 N, C, IW = input.shape
102 OW = out.shape[-1]
104 sN_in, sC_in, sW_in = input.stride()
105 sN_out, sC_out, sW_out = out.stride()
107 BLOCK_W = 256
108 grid = (triton.cdiv(OW, BLOCK_W), N * C)
110 use_scales = scale is not None and output_size is None
111 scale_w = float(scale) if use_scales else 1.0
113 _upsample_nearest_exact1d_kernel[grid](
114 input,
115 out,
116 N,
117 C,
118 IW,
119 OW,
120 sN_in,
121 sC_in,
122 sW_in,
123 sN_out,
124 sC_out,
125 sW_out,
126 use_scales=use_scales,
127 scale_w=scale_w,
128 BLOCK_W=BLOCK_W,
129 )
130 return out
def _extract_io_and_params(args, kwargs, expect_out=False):
    """Pull the input tensor, optional out tensor, size and scale from a call.

    The input is looked up under the ``input`` / ``self`` keywords, then as the
    first positional argument if that is a tensor (in which case it is consumed
    from ``args``). Size and scale are read from their keyword aliases first;
    any still missing are taken, in that order, from the leading non-tensor
    positionals. When ``expect_out`` is true the out tensor comes from the
    ``out`` keyword or, failing that, the last tensor among the remaining
    positionals.

    Returns:
        ``(in_t, out_t, out_w, scale_w)`` where ``out_t`` is ``None`` unless
        ``expect_out`` is set, and ``out_w`` / ``scale_w`` are normalized by
        ``_parse_size_1d`` / ``_parse_scale_1d``.

    Raises:
        ValueError: if no input tensor is found, or ``expect_out`` is set and
            no out tensor is found.
    """

    def _is_tensor(obj):
        return isinstance(obj, torch.Tensor)

    def _first_present(names):
        # Value of the first keyword that is present at all (even if None).
        for name in names:
            if name in kwargs:
                return kwargs[name]
        return None

    # --- input tensor ---
    in_t = kwargs.get("input")
    if in_t is None:
        in_t = kwargs.get("self")
    if in_t is None and args and _is_tensor(args[0]):
        in_t, args = args[0], args[1:]
    if not _is_tensor(in_t):
        raise ValueError("Input tensor not found for _upsample_nearest_exact1d.")

    # --- size / scale from keywords ---
    output_size = _first_present(("output_size", "size", "output_size_list"))
    scales = _first_present(("scale_factor", "scales", "scale_factors", "scale"))

    # --- size / scale from leading non-tensor positionals ---
    cursor = 0
    if output_size is None and cursor < len(args) and not _is_tensor(args[cursor]):
        output_size = args[cursor]
        cursor += 1
    if scales is None and cursor < len(args) and not _is_tensor(args[cursor]):
        scales = args[cursor]

    # --- out tensor ---
    out_t = None
    if expect_out:
        out_t = kwargs.get("out")
        if out_t is None:
            # Fall back to the last tensor among the remaining positionals.
            out_t = next((a for a in reversed(args) if _is_tensor(a)), None)
        if out_t is None:
            raise ValueError(
                "Output tensor 'out' not found for _upsample_nearest_exact1d_out."
            )

    # Collapse list-like size/scale to their single 1D value.
    return in_t, out_t, _parse_size_1d(output_size), _parse_scale_1d(scales)
188def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None):
189 N, C, IW = in_t.shape
190 OW = _compute_out_w(IW, out_w, scale_w)
191 if OW < 0:
192 raise ValueError("Output width must be non-negative.")
193 if dtype is None:
194 dtype = in_t.dtype
195 if device is None:
196 device = in_t.device
197 return torch.empty((N, C, OW), dtype=dtype, device=device)
def _upsample_nearest_exact1d(*args, **kwargs):
    """Functional entry point: allocate the output and run the upsample.

    Arguments are interpreted by ``_extract_io_and_params``. An output with no
    elements is returned as allocated, without launching anything.
    """
    src, _, width, scale = _extract_io_and_params(args, kwargs, expect_out=False)
    dst = _prepare_out_tensor(src, width, scale)
    if dst.numel() != 0:
        return _launch_upsample_nearest_exact1d_kernel(
            src, dst, output_size=width, scale=scale
        )
    return dst
def _upsample_nearest_exact1d_out(*args, **kwargs):
    """Out-variant entry point: upsample into a caller-provided tensor.

    Arguments are interpreted by ``_extract_io_and_params`` with
    ``expect_out=True``.

    Returns:
        The result of the launcher, or the out tensor unchanged when it has no
        elements.

    Raises:
        ValueError: if the out or input tensor is not 3D, if their batch and
            channel dimensions differ, or if the out width does not match the
            width implied by the size/scale arguments.
    """
    in_t, out_t, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=True)
    if out_t.ndim != 3:
        raise ValueError(
            f"Out tensor must be 3D (N, C, W); got shape {tuple(out_t.shape)}"
        )
    if in_t.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(in_t.shape)}"
        )
    # The launch grid is sized from the input's N and C while addressing out
    # through out's own strides, so mismatched leading dims would index
    # outside the out tensor.
    if tuple(out_t.shape[:2]) != tuple(in_t.shape[:2]):
        raise ValueError(
            f"Provided out tensor has leading dims {tuple(out_t.shape[:2])} "
            f"but input has {tuple(in_t.shape[:2])}."
        )
    # Validate that out_t has the width implied by the provided parameters.
    expected_w = _compute_out_w(in_t.shape[-1], out_w, scale_w)
    if out_t.shape[-1] != expected_w:
        raise ValueError(
            f"Provided out tensor has width {out_t.shape[-1]} but expected {expected_w}."
        )
    if out_t.numel() == 0:
        return out_t
    return _launch_upsample_nearest_exact1d_kernel(
        in_t, out_t, output_size=out_w, scale=scale_w
    )
def _upsample_nearest_exact1d_vec(*args, **kwargs):
    """Vec-variant entry point.

    Handled exactly like ``_upsample_nearest_exact1d``: argument extraction
    already accepts list-like output sizes and scales, so this simply forwards
    the call.
    """
    return _upsample_nearest_exact1d(*args, **kwargs)