Coverage for src/flag_gems/experimental_ops/upsample_nearest1d.py: 0%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _upsample_nearest1d_kernel(
    in_ptr,
    out_ptr,
    C,
    W_IN,
    W_OUT,
    in_stride_n,
    in_stride_c,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_w,
    src_per_dst,
    BLOCK_W: tl.constexpr,
):
    """Copy the nearest source element for one block of output columns.

    The launch grid is one-dimensional: program ``pid`` handles the row
    ``nc = pid // num_w_blocks`` (with ``nc = n * C + c``) and the output
    columns ``[w_blk * BLOCK_W, (w_blk + 1) * BLOCK_W)`` where
    ``w_blk = pid % num_w_blocks``.  Keeping everything on grid axis 0 avoids
    CUDA's 65535 limit on grid axes 1 and 2, which a dedicated (n, c) axis
    would exceed as soon as ``N * C > 65535``.

    The source column is ``min(floor(w * src_per_dst), W_IN - 1)`` evaluated
    in fp32, following the formula of ATen's
    ``nearest_neighbor_compute_source_index``.
    """
    pid = tl.program_id(0)
    num_w_blocks = (W_OUT + BLOCK_W - 1) // BLOCK_W
    nc = pid // num_w_blocks
    w_blk = pid % num_w_blocks

    # 64-bit row indices so that n * stride cannot overflow int32.
    n = (nc // C).to(tl.int64)
    c = (nc % C).to(tl.int64)

    offs_w = w_blk * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < W_OUT

    # offs_w is non-negative, so truncation toward zero is floor().
    src_w = (offs_w.to(tl.float32) * src_per_dst).to(tl.int32)
    src_w = tl.minimum(src_w, W_IN - 1)

    in_offs = (
        n * in_stride_n + c * in_stride_c + src_w.to(tl.int64) * in_stride_w
    )
    out_offs = (
        n * out_stride_n + c * out_stride_c + offs_w.to(tl.int64) * out_stride_w
    )

    val = tl.load(in_ptr + in_offs, mask=mask, other=0)
    tl.store(out_ptr + out_offs, val, mask=mask)


def _parse_scale_factor(scales):
    """Return the first scale factor in ``scales`` as a float, or None.

    ``scales`` may be None, a number, or a list/tuple whose first element is
    the width scale; an empty sequence or a leading None yields None.
    """
    if scales is None:
        return None
    if isinstance(scales, (list, tuple)):
        if len(scales) == 0 or scales[0] is None:
            return None
        return float(scales[0])
    return float(scales)


def _upsample_nearest1d_impl(
    input: torch.Tensor, output_size=None, scales=None, out: torch.Tensor = None
):
    """Shared implementation behind the upsample_nearest1d entry points.

    Args:
        input: CUDA tensor of shape (N, C, W_in) with W_in > 0.
        output_size: optional length-1 list/tuple giving the output width.
        scales: optional scale factor (a number, or a list/tuple whose first
            element is used).  When ``output_size`` is omitted it sets
            ``W_out = floor(W_in * scale)``.  Whenever it is positive it also
            drives the index mapping ``src = floor(dst / scale)`` -- even if
            ``output_size`` is given -- mirroring ATen's
            ``compute_scales_value``; otherwise ``W_in / W_out`` is used.
        out: optional preallocated CUDA tensor of shape (N, C, W_out) on the
            input's device with the input's dtype.

    Returns:
        ``out`` if it was supplied, otherwise a newly allocated tensor.

    Raises:
        ValueError: for a non-CUDA or non-3D input, zero input width,
            malformed ``output_size``/``scales``, a non-positive output width,
            or an ``out`` tensor with the wrong device, shape, or dtype.
    """
    if not input.is_cuda:
        raise ValueError("Input tensor must be on CUDA device.")
    if input.dim() != 3:
        raise ValueError("upsample_nearest1d expects a 3D tensor of shape (N, C, W).")
    N, C, W_in = input.shape
    if W_in <= 0:
        # Clamping to W_in - 1 would otherwise produce index -1 (OOB read).
        raise ValueError("Input width must be positive.")

    if output_size is not None:
        if not isinstance(output_size, (list, tuple)) or len(output_size) != 1:
            raise ValueError(
                "output_size must be a sequence of length 1 for 1D upsampling."
            )
        W_out = int(output_size[0])
        scale_factor = _parse_scale_factor(scales)
    else:
        # derive from scales
        if scales is None:
            raise ValueError("Either output_size or scales must be provided.")
        scale_factor = _parse_scale_factor(scales)
        if scale_factor is None:
            raise ValueError("Invalid scales for 1D upsampling.")
        if scale_factor <= 0:
            raise ValueError("Scale factor must be positive.")
        W_out = int(math.floor(W_in * scale_factor))

    if W_out <= 0:
        raise ValueError("Computed output width must be positive.")

    # Source columns per output column.  A positive explicit scale wins over
    # the size ratio, as in ATen (non-positive scales fall back to the ratio).
    if scale_factor is not None and scale_factor > 0:
        src_per_dst = 1.0 / scale_factor
    else:
        src_per_dst = W_in / W_out

    # Prepare output
    if out is None:
        out = torch.empty((N, C, W_out), device=input.device, dtype=input.dtype)
    else:
        if not out.is_cuda:
            raise ValueError("Output tensor must be on CUDA device.")
        if out.device != input.device:
            raise ValueError("Output tensor must be on the same device as input.")
        if list(out.shape) != [N, C, W_out]:
            raise ValueError(
                f"Output tensor has incorrect shape, expected ({N}, {C}, {W_out})."
            )
        if out.dtype != input.dtype:
            raise ValueError("Output tensor must have the same dtype as input.")

    if N * C == 0:
        # Nothing to compute; also avoids launching an empty grid.
        return out

    in_stride_n, in_stride_c, in_stride_w = input.stride()
    out_stride_n, out_stride_c, out_stride_w = out.stride()

    BLOCK_W = 256
    grid = (N * C * triton.cdiv(W_out, BLOCK_W),)
    # Launch on the input's device, not whatever device is current.
    with torch.cuda.device(input.device):
        _upsample_nearest1d_kernel[grid](
            input,
            out,
            C,
            W_in,
            W_out,
            in_stride_n,
            in_stride_c,
            in_stride_w,
            out_stride_n,
            out_stride_c,
            out_stride_w,
            float(src_per_dst),
            BLOCK_W=BLOCK_W,
        )
    return out
def upsample_nearest1d(input: torch.Tensor, output_size=None, scales=None):
    """Nearest-neighbour upsample a (N, C, W) tensor along its last axis.

    Thin entry point: forwards ``output_size`` and ``scales`` unchanged to
    ``_upsample_nearest1d_impl`` without an ``out`` tensor, so a new output
    tensor is allocated and returned.
    """
    result = _upsample_nearest1d_impl(input, output_size, scales, None)
    return result
def upsample_nearest1d_vec(input: torch.Tensor, output_size=None, scales=None):
    """Sequence-scale variant of ``upsample_nearest1d``.

    ``scales`` is normally a one-element sequence; it and ``output_size`` are
    handed to ``_upsample_nearest1d_impl`` untouched, with no ``out`` tensor,
    so the result is a freshly allocated tensor.
    """
    result = _upsample_nearest1d_impl(input, output_size, scales, None)
    return result
def upsample_nearest1d_out(
    input: torch.Tensor, output_size=None, scales=None, *, out: torch.Tensor
):
    """Nearest-neighbour upsample ``input`` into the preallocated ``out``.

    ``out`` is validated (CUDA, expected shape, matching dtype) and filled by
    ``_upsample_nearest1d_impl``; the same ``out`` object is returned.
    """
    destination = out
    _upsample_nearest1d_impl(input, output_size, scales, destination)
    return destination