Coverage for src/flag_gems/ops/upsample_bicubic2d.py: 28%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2import math
3from typing import Sequence
5import torch
6import triton
7import triton.language as tl
9logger = logging.getLogger(__name__)
@triton.jit
def cubic_weight(d, a: tl.constexpr):
    # Piecewise cubic convolution weight for a tap at signed distance `d`
    # from the sample position, parameterised by the constant `a`:
    #   |d| <= 1     : (a + 2)|d|^3 - (a + 3)|d|^2 + 1
    #   1 < |d| < 2  : a|d|^3 - 5a|d|^2 + 8a|d| - 4a
    #   otherwise    : 0
    dist = tl.abs(d)
    dist_sq = dist * dist
    dist_cu = dist_sq * dist
    inner = (a + 2.0) * dist_cu - (a + 3.0) * dist_sq + 1.0
    outer = a * dist_cu - 5.0 * a * dist_sq + 8.0 * a * dist - 4.0 * a
    # Resolve the outer lobe first, then let the inner lobe take precedence.
    outer_or_zero = tl.where(dist < 2.0, outer, 0.0)
    return tl.where(dist <= 1.0, inner, outer_or_zero)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_W": 128}, num_warps=4),
        triton.Config({"BLOCK_W": 256}, num_warps=4),
        triton.Config({"BLOCK_W": 256}, num_warps=8),
        triton.Config({"BLOCK_W": 512}, num_warps=8),
        triton.Config({"BLOCK_W": 1024}, num_warps=8),
    ],
    key=["W_out"],
)
@triton.jit
def _upsample_bicubic2d_row_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    strideN,
    strideC,
    strideH,
    strideW,
    out_strideN,
    out_strideC,
    out_strideH,
    out_strideW,
    scale_h,
    scale_w,
    align_corners: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Bicubic 2D resampling. Each program instance produces one BLOCK_W-wide
    # segment of a single output row, identified by (n, c, y_out).
    #
    # in_ptr / out_ptr      : input and output base pointers.
    # N, C, H_in, W_in      : input extents; H_out, W_out: output extents.
    # stride* / out_stride* : element strides of the input / output.
    # scale_h, scale_w      : output-index -> input-coordinate scale factors.
    # align_corners         : selects the coordinate mapping (see below).
    #
    # All interpolation arithmetic is carried out in float32; the result is
    # cast to the output element type on store.

    # Decompose the flat program id into (row_id, pid_w), then row_id into
    # (n, c, y_out).
    pid = tl.program_id(0)
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)

    pid_w = pid % num_w_blocks
    row_id = pid // num_w_blocks

    y_out = row_id % H_out
    nc = row_id // H_out
    c = nc % C
    n = nc // C

    # Vertical source coordinate (a scalar, shared by the whole segment).
    fy = y_out * 1.0
    if align_corners:
        in_y = fy * scale_h
    else:
        in_y = (fy + 0.5) * scale_h - 0.5

    y0f = tl.floor(in_y)
    y0 = y0f.to(tl.int32)
    ty = in_y - y0f

    # The four vertical taps, clamped into [0, H_in - 1] so that taps falling
    # outside the image reuse the nearest edge row.
    y_m1 = tl.maximum(0, tl.minimum(H_in - 1, y0 - 1))
    y_0 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 0))
    y_p1 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 1))
    y_p2 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 2))

    a = -0.75
    wy0 = cubic_weight(1.0 + ty, a)
    wy1 = cubic_weight(ty, a)
    wy2 = cubic_weight(1.0 - ty, a)
    wy3 = cubic_weight(2.0 - ty, a)

    # Pointer arithmetic is done in int64 so that large tensors do not
    # overflow 32-bit offsets.
    n_64 = n.to(tl.int64)
    c_64 = c.to(tl.int64)
    base_ptr = in_ptr + n_64 * strideN + c_64 * strideC

    row_m1_ptr = base_ptr + y_m1.to(tl.int64) * strideH
    row_0_ptr = base_ptr + y_0.to(tl.int64) * strideH
    row_p1_ptr = base_ptr + y_p1.to(tl.int64) * strideH
    row_p2_ptr = base_ptr + y_p2.to(tl.int64) * strideH

    # Output columns handled by this program; lanes past W_out are masked
    # on store.
    x_out = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = x_out < W_out

    # Horizontal source coordinates (one per lane).
    fx = x_out.to(tl.float32)
    if align_corners:
        in_x = fx * scale_w
    else:
        in_x = (fx + 0.5) * scale_w - 0.5

    x0f = tl.floor(in_x)
    x0 = x0f.to(tl.int32)
    tx = in_x - x0f

    # The four horizontal taps, clamped into [0, W_in - 1]. Because every
    # index is clamped, the loads below are in bounds even for masked lanes
    # and therefore need no load mask.
    x_m1 = tl.maximum(0, tl.minimum(W_in - 1, x0 - 1))
    x_0 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 0))
    x_p1 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 1))
    x_p2 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 2))

    wx0 = cubic_weight(1.0 + tx, a)
    wx1 = cubic_weight(tx, a)
    wx2 = cubic_weight(1.0 - tx, a)
    wx3 = cubic_weight(2.0 - tx, a)

    # Widen to int64 *before* multiplying by the stride, matching the row
    # offsets above; an int32 product could overflow for large strides.
    off_x_m1 = x_m1.to(tl.int64) * strideW
    off_x_0 = x_0.to(tl.int64) * strideW
    off_x_p1 = x_p1.to(tl.int64) * strideW
    off_x_p2 = x_p2.to(tl.int64) * strideW

    # Separable accumulation: horizontal blend per source row, then weight
    # by the vertical coefficient of that row.
    v0 = tl.load(row_m1_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_m1_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_m1_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_m1_ptr + off_x_p2).to(tl.float32)
    acc = (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy0

    v0 = tl.load(row_0_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_0_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_0_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_0_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy1

    v0 = tl.load(row_p1_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_p1_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_p1_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_p1_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy2

    v0 = tl.load(row_p2_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_p2_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_p2_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_p2_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy3

    out_offset = (
        n_64 * out_strideN
        + c_64 * out_strideC
        + y_out.to(tl.int64) * out_strideH
        + x_out.to(tl.int64) * out_strideW
    )
    tl.store(out_ptr + out_offset, acc.to(out_ptr.dtype.element_ty), mask=mask)
157def upsample_bicubic2d(
158 input: torch.Tensor,
159 output_size: Sequence[int] | None = None,
160 align_corners: bool = False,
161 scales_h: float | None = None,
162 scales_w: float | None = None,
163) -> torch.Tensor:
164 logger.debug("GEMS UPSAMPLE BICUBIC2D")
165 scale_factors = (scales_h, scales_w)
167 if input.dim() != 4:
168 raise ValueError("input must be a 4D tensor (N, C, H, W)")
169 if output_size is None and scale_factors is None:
170 raise ValueError("Either output_size or scale_factors must be provided")
172 N, C, H_in, W_in = input.shape
174 if output_size is not None:
175 if len(output_size) != 2:
176 raise ValueError(
177 "output_size must be a sequence of two ints (H_out, W_out)"
178 )
179 H_out, W_out = int(output_size[0]), int(output_size[1])
180 else:
181 if len(scale_factors) == 2:
182 sh, sw = float(scale_factors[0]), float(scale_factors[1])
183 elif len(scale_factors) == 1:
184 sh = sw = float(scale_factors[0])
185 else:
186 raise ValueError("scale_factors must have length 1 or 2 for 2D upsampling")
187 H_out = max(int(math.floor(H_in * sh)), 1)
188 W_out = max(int(math.floor(W_in * sw)), 1)
190 if H_out <= 0 or W_out <= 0:
191 raise ValueError("Output size must be positive")
193 device = input.device
194 if not input.is_cuda:
195 raise ValueError("This Triton kernel requires CUDA tensors")
197 if align_corners:
198 scale_h = 0.0 if H_out <= 1 else (H_in - 1.0) / (H_out - 1.0)
199 scale_w = 0.0 if W_out <= 1 else (W_in - 1.0) / (W_out - 1.0)
200 else:
201 scale_h = float(H_in) / float(H_out)
202 scale_w = float(W_in) / float(W_out)
204 out = torch.empty((N, C, H_out, W_out), dtype=input.dtype, device=device)
206 sN, sC, sH, sW = input.stride()
207 oN, oC, oH, oW = out.stride()
209 grid = lambda meta: (triton.cdiv(W_out, meta["BLOCK_W"]) * N * C * H_out,)
211 _upsample_bicubic2d_row_kernel[grid](
212 input,
213 out,
214 N,
215 C,
216 H_in,
217 W_in,
218 H_out,
219 W_out,
220 sN,
221 sC,
222 sH,
223 sW,
224 oN,
225 oC,
226 oH,
227 oW,
228 float(scale_h),
229 float(scale_w),
230 align_corners,
231 )
233 return out