Coverage for src/flag_gems/ops/_upsample_nearest_exact1d.py: 53%
133 statements
coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
# Module-level logger; each op entry point below emits a debug line when called.
logger = logging.getLogger(__name__)
@triton.jit
def _upsample_nearest_exact1d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    IW,
    OW,
    sN_in,
    sC_in,
    sW_in,
    sN_out,
    sC_out,
    sW_out,
    use_scales: tl.constexpr,
    scale_w,
    BLOCK_W: tl.constexpr,
):
    # Nearest-exact 1D upsampling of one BLOCK_W-wide strip of one (n, c) row.
    #   grid axis 0: tiles of BLOCK_W output columns
    #   grid axis 1: flattened (n, c) row index in [0, N * C)
    # Output column ow reads source column
    #   iw = min(floor((ow + 0.5) * s), IW - 1)
    # with s = 1 / scale_w when use_scales is set, else s = IW / OW. The +0.5
    # half-pixel offset is what distinguishes "nearest-exact" from "nearest"
    # (cf. ATen's nearest_neighbor_exact_compute_source_index).
    # N is not needed inside the kernel; it is kept for the launch signature.
    pid_w = tl.program_id(0)
    # Promote to int64 so n * stride and c * stride cannot overflow int32.
    pid_nc = tl.program_id(1).to(tl.int64)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < OW

    # Recover (n, c) from the flattened row index.
    n = pid_nc // C
    c = pid_nc - n * C
    base_in = n * sN_in + c * sC_in
    base_out = n * sN_out + c * sC_out

    # Sample at output pixel centres, computed in float32.
    src_center = offs_w.to(tl.float32) + 0.5
    if use_scales:
        src_f = src_center * (1.0 / scale_w)
    else:
        src_f = src_center * ((IW * 1.0) / (OW * 1.0))
    iw = tl.floor(src_f).to(tl.int32)
    # Clamp on every path: float rounding (or a scale that disagrees with OW)
    # must never index past the last input column.
    iw = tl.minimum(tl.maximum(iw, 0), IW - 1)

    x = tl.load(in_ptr + base_in + iw * sW_in, mask=mask)
    tl.store(out_ptr + base_out + offs_w * sW_out, x, mask=mask)
def _parse_size_1d(val):
    """Normalise a 1D output-size argument to an ``int`` (or None).

    Sequences (list, tuple, ``torch.Size``) contribute their last element; an
    empty sequence or None yields None; any other value is coerced with ``int``.
    """
    if val is None:
        return None
    # torch.Size is a tuple subclass, so it shares the sequence handling.
    if isinstance(val, (torch.Size, list, tuple)):
        return int(val[-1]) if val else None
    return int(val)
def _parse_scale_1d(val):
    """Normalise a 1D scale argument to a ``float`` (or None).

    A list/tuple contributes its last element (None if empty); any other
    non-None value is coerced with ``float``.
    """
    if val is None:
        return None
    if not isinstance(val, (list, tuple)):
        return float(val)
    return float(val[-1]) if len(val) else None
def _compute_out_w(iw, output_size, scale):
    """Resolve the output width for an input of width ``iw``.

    An explicit ``output_size`` takes precedence and is coerced with ``int``.
    Otherwise the width is ``floor(iw * scale)``.

    Raises:
        ValueError: if both ``output_size`` and ``scale`` are None.
    """
    if output_size is None:
        if scale is None:
            raise ValueError(
                "Either output_size or scale must be provided for _upsample_nearest_exact1d."
            )
        # Common convention: OW = floor(IW * scale).
        return int(math.floor(iw * scale))
    return int(output_size)
def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None):
    """Write the nearest-exact 1D upsampling of ``input`` into ``out``; return ``out``.

    Args:
        input: 3D tensor of shape (N, C, IW).
        out: preallocated 3D tensor of shape (N, C, OW); its last dimension is
            the output width that gets written.
        output_size: requested output width. Kept for API compatibility; the
            width actually written is ``out.shape[-1]``.
        scale: optional scale factor (output width / input width). As in ATen's
            ``compute_scales_value``, a positive value drives the source-index
            mapping even when an output size is also given; otherwise IW / OW
            is used.

    Returns:
        ``out``, filled in place (on the non-CUDA fallback path as well).

    Raises:
        ValueError: if ``input`` or ``out`` is not 3D, their leading (N, C)
            dimensions differ, or the input width is 0 while ``out`` is
            non-empty.
    """
    if input.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}"
        )
    # The kernel derives its rows from the input's (N, C) and indexes `out` with
    # the same (n, c); a mismatch would read/write out of bounds.
    if out.ndim != 3 or out.shape[:2] != input.shape[:2]:
        raise ValueError(
            "_upsample_nearest_exact1d expects out of shape (N, C, OW) matching the "
            f"input's (N, C); got input {tuple(input.shape)} and out {tuple(out.shape)}"
        )

    # ATen honours any positive scale regardless of whether output_size is set.
    use_scales = scale is not None and scale > 0

    if not input.is_cuda or not out.is_cuda:
        # Fallback to the native operator on CPU or non-CUDA devices. The ATen op
        # takes a scalar `float? scales` and returns a new tensor, so copy the
        # result into `out` to honour the out-parameter contract.
        result = torch.ops.aten._upsample_nearest_exact1d.default(
            input, [out.shape[-1]], float(scale) if scale is not None else None
        )
        out.copy_(result)
        return out

    N, C, IW = input.shape
    OW = out.shape[-1]
    if out.numel() == 0:
        return out
    if IW == 0:
        raise ValueError(
            "_upsample_nearest_exact1d cannot upsample an input of width 0 to a "
            f"non-empty output; got input {tuple(input.shape)} and out {tuple(out.shape)}"
        )

    sN_in, sC_in, sW_in = input.stride()
    sN_out, sC_out, sW_out = out.stride()

    BLOCK_W = 256
    scale_w = float(scale) if use_scales else 1.0

    # Grid axis 1 enumerates (n, c) rows, but CUDA caps gridDim.y at 65535, so
    # launch over sub-views with at most that many rows each. Slices keep the
    # parent strides; only the base pointer and the (N, C) extents change.
    # In the common case (N * C <= 65535) this is a single launch.
    max_rows = 65535
    c_step = min(C, max_rows)
    n_step = max(1, max_rows // c_step)
    for n0 in range(0, N, n_step):
        for c0 in range(0, C, c_step):
            in_view = input[n0 : n0 + n_step, c0 : c0 + c_step]
            out_view = out[n0 : n0 + n_step, c0 : c0 + c_step]
            n_rows, c_rows = in_view.shape[0], in_view.shape[1]
            grid = (triton.cdiv(OW, BLOCK_W), n_rows * c_rows)
            _upsample_nearest_exact1d_kernel[grid](
                in_view,
                out_view,
                n_rows,
                c_rows,
                IW,
                OW,
                sN_in,
                sC_in,
                sW_in,
                sN_out,
                sC_out,
                sW_out,
                use_scales=use_scales,
                scale_w=scale_w,
                BLOCK_W=BLOCK_W,
            )
    return out
def _extract_io_and_params(args, kwargs, expect_out=False):
    """Normalise the supported calling conventions into (input, out, out_w, scale_w).

    The input tensor is looked up as ``kwargs["input"]``, then ``kwargs["self"]``,
    then the first positional argument if it is a tensor (which is then
    consumed). The output size comes from the first present key among
    ``output_size``, ``size`` and ``output_size_list``. The scales come from the
    first present key among ``scale_factor``, ``scales``, ``scale_factors`` and
    ``scale``. Values still missing are filled from the remaining leading
    non-tensor positionals: size first, then scales.

    When ``expect_out`` is true, ``out`` comes from ``kwargs["out"]`` or, failing
    that, the last tensor among the remaining positional arguments.

    Returns:
        ``(in_t, out_t, out_w, scale_w)``. ``out_t`` is None unless
        ``expect_out``. ``out_w`` and ``scale_w`` are normalised through
        ``_parse_size_1d`` and ``_parse_scale_1d``.

    Raises:
        ValueError: if no input tensor (or, with ``expect_out``, no out tensor)
            can be found.
    """

    def first_present(names):
        # Value of the first key present in kwargs, even if that value is None.
        return next((kwargs[name] for name in names if name in kwargs), None)

    in_t = kwargs.get("input")
    if in_t is None:
        in_t = kwargs.get("self")
    rest = args
    if in_t is None and rest and isinstance(rest[0], torch.Tensor):
        in_t, rest = rest[0], rest[1:]
    if not isinstance(in_t, torch.Tensor):
        raise ValueError("Input tensor not found for _upsample_nearest_exact1d.")

    output_size = first_present(("output_size", "size", "output_size_list"))
    scales = first_present(("scale_factor", "scales", "scale_factors", "scale"))

    # Fill any still-missing size/scale from the leading non-tensor positionals.
    # A tensor at the cursor stops consumption for that slot.
    slots = [output_size, scales]
    cursor = 0
    for idx, current in enumerate(slots):
        if current is not None or cursor >= len(rest):
            continue
        if isinstance(rest[cursor], torch.Tensor):
            continue
        slots[idx] = rest[cursor]
        cursor += 1
    output_size, scales = slots

    out_t = None
    if expect_out:
        out_t = kwargs.get("out")
        if out_t is None:
            # Fall back to the last tensor among the remaining positionals.
            out_t = next(
                (a for a in reversed(rest) if isinstance(a, torch.Tensor)), None
            )
        if out_t is None:
            raise ValueError(
                "Output tensor 'out' not found for _upsample_nearest_exact1d_out."
            )

    return in_t, out_t, _parse_size_1d(output_size), _parse_scale_1d(scales)
def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None):
    """Allocate an uninitialised (N, C, OW) result tensor for ``in_t``.

    Args:
        in_t: input tensor; must be 3D (N, C, IW).
        out_w: explicit output width, or None to derive it from ``scale_w``.
        scale_w: scale factor used when ``out_w`` is None.
        dtype: result dtype; defaults to ``in_t.dtype``.
        device: result device; defaults to ``in_t.device``.

    Raises:
        ValueError: if ``in_t`` is not 3D, neither ``out_w`` nor ``scale_w`` is
            given, or the computed width is negative.
    """
    # Check rank explicitly rather than failing with a cryptic unpacking error.
    if in_t.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(in_t.shape)}"
        )
    N, C, IW = in_t.shape
    OW = _compute_out_w(IW, out_w, scale_w)
    if OW < 0:
        raise ValueError("Output width must be non-negative.")
    return torch.empty(
        (N, C, OW),
        dtype=in_t.dtype if dtype is None else dtype,
        device=in_t.device if device is None else device,
    )
def _upsample_nearest_exact1d(*args, **kwargs):
    """Functional nearest-exact 1D upsampling that returns a freshly allocated tensor.

    The output width comes from an explicit output size if one is supplied,
    otherwise from the scale factor. An empty result is returned without
    launching the kernel.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D")
    src, _, width, factor = _extract_io_and_params(args, kwargs, expect_out=False)
    result = _prepare_out_tensor(src, width, factor)
    if result.numel() != 0:
        result = _launch_upsample_nearest_exact1d_kernel(
            src, result, output_size=width, scale=factor
        )
    return result
def _upsample_nearest_exact1d_out(*args, **kwargs):
    """Out-variant of nearest-exact 1D upsampling: fill the caller's ``out`` tensor.

    Arguments are resolved by ``_extract_io_and_params`` with ``expect_out=True``.
    ``out`` comes from the ``out`` keyword or the last remaining positional
    tensor.

    Returns:
        ``out``. It is returned untouched when it has no elements.

    Raises:
        ValueError: if the input or ``out`` is not 3D, their (N, C) dimensions
            differ, or ``out``'s width differs from the computed output width.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_OUT")
    in_t, out_t, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=True)
    if in_t.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(in_t.shape)}"
        )
    if out_t.ndim != 3:
        raise ValueError(
            f"Out tensor must be 3D (N, C, W); got shape {tuple(out_t.shape)}"
        )
    # The launcher sizes its grid from the input's N and C but writes through
    # out's strides, so mismatched leading dims would write out of bounds.
    if out_t.shape[:2] != in_t.shape[:2]:
        raise ValueError(
            f"Out tensor leading dims {tuple(out_t.shape[:2])} must match the "
            f"input's (N, C) {tuple(in_t.shape[:2])}."
        )
    # Validate that out_t has the correct computed width.
    expected_w = _compute_out_w(in_t.shape[-1], out_w, scale_w)
    if out_t.shape[-1] != expected_w:
        raise ValueError(
            f"Provided out tensor has width {out_t.shape[-1]} but expected {expected_w}."
        )
    if out_t.numel() == 0:
        return out_t
    return _launch_upsample_nearest_exact1d_kernel(
        in_t, out_t, output_size=out_w, scale=scale_w
    )
def _upsample_nearest_exact1d_vec(*args, **kwargs):
    """``.vec`` overload: runs the same pipeline as the base functional op.

    ``output_size`` and the scales may be list-like; only their last entry is
    used. The result is freshly allocated, and an empty result skips the
    kernel launch.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_VEC")
    params = _extract_io_and_params(args, kwargs, expect_out=False)
    inp, width, factor = params[0], params[2], params[3]
    dst = _prepare_out_tensor(inp, width, factor)
    if not dst.numel():
        return dst
    return _launch_upsample_nearest_exact1d_kernel(
        inp, dst, output_size=width, scale=factor
    )