Coverage for src/flag_gems/ops/upsample_bicubic2d_aa.py: 21%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
# Module logger; the entry point below emits a debug record on every call.
logger = logging.getLogger(__name__)

# Rebind `device` from the imported runtime device object to its `.name`;
# `_upsample_bicubic2d_aa` compares `input.device.type` against this value.
device = device.name
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
    key=["N", "C", "OH", "OW"],
)
@triton.jit
def upsample_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    """Anti-aliased bicubic resize, fast path for upsampling on both axes.

    The host wrapper launches this kernel only when both reciprocal scales are
    below 1. The cubic filter therefore keeps a fixed support of 2 source
    pixels and is not stretched (invscale = 1), so at most 5 source taps per
    axis can contribute. Those 5 taps are fully unrolled below; taps beyond
    the clipped span receive weight 0.

    Grid: axis 0 tiles OW by BLOCK_X, axis 1 tiles OH by BLOCK_Y, and every
    program loops over all N * C planes. Input and output are addressed as
    dense NCHW. Offsets are computed in int64 so tensors with 2**31 or more
    elements do not overflow the address arithmetic.
    """
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Lanes past the right/bottom edge wrap back onto valid pixels. They
    # recompute and store identical values, so the final store needs no mask.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    support_w = 2.0
    support_h = 2.0

    # _compute_weights_span: first contributing source index and number of
    # contributing source pixels per output coordinate, clipped to the input.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h
    invscale_w = 1.0
    invscale_h = 1.0
    # Cubic-convolution parameter.
    a = -0.5

    # Vertical tap weights: cubic kernel evaluated at the distance of tap k
    # from the center; zero for taps outside the clipped span.
    wy0 = tl.abs((0 + start_minus_center_h + 0.5) * invscale_h)
    weight_y0 = tl.where(
        0 < span_size_h,
        tl.where(
            wy0 < 1.0,
            ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
            tl.where(wy0 < 2.0, (((wy0 - 5) * wy0 + 8) * wy0 - 4) * a, 0),
        ),
        0,
    )
    wy1 = tl.abs((1 + start_minus_center_h + 0.5) * invscale_h)
    weight_y1 = tl.where(
        1 < span_size_h,
        tl.where(
            wy1 < 1.0,
            ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
            tl.where(wy1 < 2.0, (((wy1 - 5) * wy1 + 8) * wy1 - 4) * a, 0),
        ),
        0,
    )
    wy2 = tl.abs((2 + start_minus_center_h + 0.5) * invscale_h)
    weight_y2 = tl.where(
        2 < span_size_h,
        tl.where(
            wy2 < 1.0,
            ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
            tl.where(wy2 < 2.0, (((wy2 - 5) * wy2 + 8) * wy2 - 4) * a, 0),
        ),
        0,
    )
    wy3 = tl.abs((3 + start_minus_center_h + 0.5) * invscale_h)
    weight_y3 = tl.where(
        3 < span_size_h,
        tl.where(
            wy3 < 1.0,
            ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
            tl.where(wy3 < 2.0, (((wy3 - 5) * wy3 + 8) * wy3 - 4) * a, 0),
        ),
        0,
    )
    wy4 = tl.abs((4 + start_minus_center_h + 0.5) * invscale_h)
    weight_y4 = tl.where(
        4 < span_size_h,
        tl.where(
            wy4 < 1.0,
            ((a + 2) * wy4 - (a + 3)) * wy4 * wy4 + 1,
            tl.where(wy4 < 2.0, (((wy4 - 5) * wy4 + 8) * wy4 - 4) * a, 0),
        ),
        0,
    )
    # Normalize so the weights sum to 1; a zero total is left unscaled.
    weight_y_total = weight_y0 + weight_y1 + weight_y2 + weight_y3 + weight_y4
    weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
    weight_y0 /= weight_y_total
    weight_y1 /= weight_y_total
    weight_y2 /= weight_y_total
    weight_y3 /= weight_y_total
    weight_y4 /= weight_y_total

    # Horizontal tap weights, computed the same way.
    wx0 = tl.abs((0 + start_minus_center_w + 0.5) * invscale_w)
    weight_x0 = tl.where(
        0 < span_size_w,
        tl.where(
            wx0 < 1.0,
            ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
            tl.where(wx0 < 2.0, (((wx0 - 5) * wx0 + 8) * wx0 - 4) * a, 0),
        ),
        0,
    )
    wx1 = tl.abs((1 + start_minus_center_w + 0.5) * invscale_w)
    weight_x1 = tl.where(
        1 < span_size_w,
        tl.where(
            wx1 < 1.0,
            ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
            tl.where(wx1 < 2.0, (((wx1 - 5) * wx1 + 8) * wx1 - 4) * a, 0),
        ),
        0,
    )
    wx2 = tl.abs((2 + start_minus_center_w + 0.5) * invscale_w)
    weight_x2 = tl.where(
        2 < span_size_w,
        tl.where(
            wx2 < 1.0,
            ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
            tl.where(wx2 < 2.0, (((wx2 - 5) * wx2 + 8) * wx2 - 4) * a, 0),
        ),
        0,
    )
    wx3 = tl.abs((3 + start_minus_center_w + 0.5) * invscale_w)
    weight_x3 = tl.where(
        3 < span_size_w,
        tl.where(
            wx3 < 1.0,
            ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
            tl.where(wx3 < 2.0, (((wx3 - 5) * wx3 + 8) * wx3 - 4) * a, 0),
        ),
        0,
    )
    wx4 = tl.abs((4 + start_minus_center_w + 0.5) * invscale_w)
    weight_x4 = tl.where(
        4 < span_size_w,
        tl.where(
            wx4 < 1.0,
            ((a + 2) * wx4 - (a + 3)) * wx4 * wx4 + 1,
            tl.where(wx4 < 2.0, (((wx4 - 5) * wx4 + 8) * wx4 - 4) * a, 0),
        ),
        0,
    )
    weight_x_total = weight_x0 + weight_x1 + weight_x2 + weight_x3 + weight_x4
    weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
    weight_x0 /= weight_x_total
    weight_x1 /= weight_x_total
    weight_x2 /= weight_x_total
    weight_x3 /= weight_x_total
    weight_x4 /= weight_x_total

    # Bounds masks for the 5x5 tap footprint (rows as [:, None], cols as [None, :]).
    mask_y0 = span_start_h[:, None] + 0 < IH
    mask_y1 = span_start_h[:, None] + 1 < IH
    mask_y2 = span_start_h[:, None] + 2 < IH
    mask_y3 = span_start_h[:, None] + 3 < IH
    mask_y4 = span_start_h[:, None] + 4 < IH
    mask_x0 = span_start_w[None, :] + 0 < IW
    mask_x1 = span_start_w[None, :] + 1 < IW
    mask_x2 = span_start_w[None, :] + 2 < IW
    mask_x3 = span_start_w[None, :] + 3 < IW
    mask_x4 = span_start_w[None, :] + 4 < IW

    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Promote the plane index to int64: (n*C + c) * IH * IW can exceed
            # the int32 range for large tensors.
            plane = (n * C + c).to(tl.int64)
            offset_base = (plane * IH + span_start_h[:, None]) * IW + span_start_w[
                None, :
            ]

            data00 = tl.load(
                ptr_i + (offset_base + 0 * IW + 0), mask=(mask_y0 & mask_x0), other=0
            )
            data01 = tl.load(
                ptr_i + (offset_base + 0 * IW + 1), mask=(mask_y0 & mask_x1), other=0
            )
            data02 = tl.load(
                ptr_i + (offset_base + 0 * IW + 2), mask=(mask_y0 & mask_x2), other=0
            )
            data03 = tl.load(
                ptr_i + (offset_base + 0 * IW + 3), mask=(mask_y0 & mask_x3), other=0
            )
            data04 = tl.load(
                ptr_i + (offset_base + 0 * IW + 4), mask=(mask_y0 & mask_x4), other=0
            )

            data10 = tl.load(
                ptr_i + (offset_base + 1 * IW + 0), mask=(mask_y1 & mask_x0), other=0
            )
            data11 = tl.load(
                ptr_i + (offset_base + 1 * IW + 1), mask=(mask_y1 & mask_x1), other=0
            )
            data12 = tl.load(
                ptr_i + (offset_base + 1 * IW + 2), mask=(mask_y1 & mask_x2), other=0
            )
            data13 = tl.load(
                ptr_i + (offset_base + 1 * IW + 3), mask=(mask_y1 & mask_x3), other=0
            )
            data14 = tl.load(
                ptr_i + (offset_base + 1 * IW + 4), mask=(mask_y1 & mask_x4), other=0
            )

            data20 = tl.load(
                ptr_i + (offset_base + 2 * IW + 0), mask=(mask_y2 & mask_x0), other=0
            )
            data21 = tl.load(
                ptr_i + (offset_base + 2 * IW + 1), mask=(mask_y2 & mask_x1), other=0
            )
            data22 = tl.load(
                ptr_i + (offset_base + 2 * IW + 2), mask=(mask_y2 & mask_x2), other=0
            )
            data23 = tl.load(
                ptr_i + (offset_base + 2 * IW + 3), mask=(mask_y2 & mask_x3), other=0
            )
            data24 = tl.load(
                ptr_i + (offset_base + 2 * IW + 4), mask=(mask_y2 & mask_x4), other=0
            )

            data30 = tl.load(
                ptr_i + (offset_base + 3 * IW + 0), mask=(mask_y3 & mask_x0), other=0
            )
            data31 = tl.load(
                ptr_i + (offset_base + 3 * IW + 1), mask=(mask_y3 & mask_x1), other=0
            )
            data32 = tl.load(
                ptr_i + (offset_base + 3 * IW + 2), mask=(mask_y3 & mask_x2), other=0
            )
            data33 = tl.load(
                ptr_i + (offset_base + 3 * IW + 3), mask=(mask_y3 & mask_x3), other=0
            )
            data34 = tl.load(
                ptr_i + (offset_base + 3 * IW + 4), mask=(mask_y3 & mask_x4), other=0
            )

            data40 = tl.load(
                ptr_i + (offset_base + 4 * IW + 0), mask=(mask_y4 & mask_x0), other=0
            )
            data41 = tl.load(
                ptr_i + (offset_base + 4 * IW + 1), mask=(mask_y4 & mask_x1), other=0
            )
            data42 = tl.load(
                ptr_i + (offset_base + 4 * IW + 2), mask=(mask_y4 & mask_x2), other=0
            )
            data43 = tl.load(
                ptr_i + (offset_base + 4 * IW + 3), mask=(mask_y4 & mask_x3), other=0
            )
            data44 = tl.load(
                ptr_i + (offset_base + 4 * IW + 4), mask=(mask_y4 & mask_x4), other=0
            )

            # Separable filter: horizontal pass per footprint row ...
            data0 = (
                data00 * weight_x0[None, :]
                + data01 * weight_x1[None, :]
                + data02 * weight_x2[None, :]
                + data03 * weight_x3[None, :]
                + data04 * weight_x4[None, :]
            )
            data1 = (
                data10 * weight_x0[None, :]
                + data11 * weight_x1[None, :]
                + data12 * weight_x2[None, :]
                + data13 * weight_x3[None, :]
                + data14 * weight_x4[None, :]
            )
            data2 = (
                data20 * weight_x0[None, :]
                + data21 * weight_x1[None, :]
                + data22 * weight_x2[None, :]
                + data23 * weight_x3[None, :]
                + data24 * weight_x4[None, :]
            )
            data3 = (
                data30 * weight_x0[None, :]
                + data31 * weight_x1[None, :]
                + data32 * weight_x2[None, :]
                + data33 * weight_x3[None, :]
                + data34 * weight_x4[None, :]
            )
            data4 = (
                data40 * weight_x0[None, :]
                + data41 * weight_x1[None, :]
                + data42 * weight_x2[None, :]
                + data43 * weight_x3[None, :]
                + data44 * weight_x4[None, :]
            )
            # ... then the vertical pass across the row results.
            result = (
                data0 * weight_y0[:, None]
                + data1 * weight_y1[:, None]
                + data2 * weight_y2[:, None]
                + data3 * weight_y3[:, None]
                + data4 * weight_y4[:, None]
            )

            offset_o = (plane * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)
# upsample and downsample
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
    key=["N", "C", "OH", "OW"],
)
@triton.jit
def general_interpolate_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    """Anti-aliased bicubic resize for any scale (used when downsampling).

    Along an axis whose reciprocal scale is >= 1, the cubic filter is
    stretched. Its support grows to 2 * scale, and its argument is compressed
    by invscale = 1 / scale. The per-pixel footprint is therefore only known
    at run time (``interpolate_h`` x ``interpolate_w`` taps), so it is walked
    with dynamic loops. Taps outside the clipped span get weight 0.

    Each footprint row is filtered horizontally and normalized by its x-weight
    total. The row results are then accumulated with the y-weights and divided
    by the y-weight total.

    Grid: axis 0 tiles OW by BLOCK_X, axis 1 tiles OH by BLOCK_Y, and every
    program loops over all N * C planes. Input and output are addressed as
    dense NCHW, with int64 offsets.
    """
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Lanes past the right/bottom edge wrap back onto valid pixels. They
    # recompute and store identical values, so the final store needs no mask.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    # tl.where (instead of a Python ternary on a runtime scalar) promotes both
    # branches to a common type. This way an integer reciprocal scale (e.g. 0
    # from align_corners with an output size of 1) still compiles; a dynamic
    # ternary requires identical branch types.
    downsample_w = reciprocal_scale_w >= 1.0
    downsample_h = reciprocal_scale_h >= 1.0
    support_w = tl.where(downsample_w, 2 * reciprocal_scale_w, 2.0)
    support_h = tl.where(downsample_h, 2 * reciprocal_scale_h, 2.0)

    # Upper bound on taps per axis: round(support) * 2 + 1.
    interpolate_w = (support_w + 0.5).to(tl.int32) * 2 + 1
    interpolate_h = (support_h + 0.5).to(tl.int32) * 2 + 1

    # _compute_weights_span: first contributing source index and number of
    # contributing source pixels per output coordinate, clipped to the input.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )

    # The selected branch never divides by a zero scale; the unselected
    # 1 / scale value is discarded by the select.
    invscale_w = tl.where(downsample_w, 1.0 / reciprocal_scale_w, 1.0)
    invscale_h = tl.where(downsample_h, 1.0 / reciprocal_scale_h, 1.0)
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h

    # Cubic-convolution parameter.
    a = -0.5
    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Promote the plane index to int64: (n*C + c) * IH * IW can exceed
            # the int32 range for large tensors.
            plane = (n * C + c).to(tl.int64)
            offset_base = (plane * IH + span_start_h[:, None]) * IW + span_start_w[
                None, :
            ]
            weight_y_total = tl.zeros((BLOCK_Y,), dtype=tl.float32)
            result = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
            for y in range(0, interpolate_h, 1):
                wy = tl.abs((y + start_minus_center_h + 0.5) * invscale_h)
                weight_y = tl.where(
                    y < span_size_h,
                    tl.where(
                        wy < 1.0,
                        ((a + 2) * wy - (a + 3)) * wy * wy + 1,
                        tl.where(wy < 2.0, (((wy - 5) * wy + 8) * wy - 4) * a, 0),
                    ),
                    0,
                )
                weight_y_total += weight_y
                # Horizontal pass over footprint row y.
                weight_x_total = tl.zeros((BLOCK_X,), dtype=tl.float32)
                buffer = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
                for x in range(0, interpolate_w, 1):
                    wx = tl.abs((x + start_minus_center_w + 0.5) * invscale_w)
                    weight_x = tl.where(
                        x < span_size_w,
                        tl.where(
                            wx < 1.0,
                            ((a + 2) * wx - (a + 3)) * wx * wx + 1,
                            tl.where(wx < 2.0, (((wx - 5) * wx + 8) * wx - 4) * a, 0),
                        ),
                        0,
                    )
                    weight_x_total += weight_x
                    data = tl.load(
                        ptr_i + (offset_base + y * IW + x),
                        mask=(span_start_h[:, None] + y < IH)
                        & (span_start_w[None, :] + x < IW),
                        other=0,
                    )
                    buffer += data * weight_x[None, :]
                # Normalize the row by its x-weight total (zero total left as-is).
                weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
                result += buffer / weight_x_total[None, :] * weight_y[:, None]
            weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
            result /= weight_y_total[:, None]
            offset_o = (plane * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)
def bicubic_reciprocal_scale(src_size, dst_size, align_corners, scale):
    """Return the number of source pixels per output pixel along one axis.

    Args:
        src_size: input size along the axis.
        dst_size: output size along the axis.
        align_corners: if True, corner pixel centers of input and output are
            aligned and the ratio is ``(src_size - 1) / (dst_size - 1)``.
        scale: optional user scale factor (output / input). It is used only
            when ``align_corners`` is False and the value is positive.

    Returns:
        The ratio as a float. With ``align_corners`` and ``dst_size <= 1`` the
        result is ``0.0``. It is a float rather than int ``0`` so the value
        reaches the Triton kernels as a floating-point scalar, like every
        other branch.
    """
    if align_corners:
        if dst_size > 1:
            return (src_size - 1) / (dst_size - 1)
        return 0.0
    # Non-positive scales are treated as "not provided".
    if scale is not None and scale > 0:
        return 1.0 / scale
    return src_size / dst_size
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L12547
def _upsample_bicubic2d_aa(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    align_corners: bool = False,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Anti-aliased bicubic resize of a 4-D tensor.

    Args:
        input: 4-D tensor, read as (N, C, IH, IW), on the FlagGems device. Any
            memory layout is accepted; it is densified to NCHW first.
        output_size: target ``(OH, OW)``.
        align_corners: passed to ``bicubic_reciprocal_scale`` for each axis.
        scales_h: optional height scale factor, passed to
            ``bicubic_reciprocal_scale``.
        scales_w: optional width scale factor, passed to
            ``bicubic_reciprocal_scale``.

    Returns:
        A new contiguous tensor of shape ``(N, C, OH, OW)`` with the dtype and
        device of ``input``.

    Raises:
        AssertionError: if the input is on another device, is not 4-D, or
            ``output_size`` does not have two entries.
    """
    logger.debug("GEMS UPSAMPLE BICUBIC2D AA")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    N, C, IH, IW = input.shape

    reciprocal_scale_h = bicubic_reciprocal_scale(IH, OH, align_corners, scales_h)
    reciprocal_scale_w = bicubic_reciprocal_scale(IW, OW, align_corners, scales_w)

    # Both kernels compute addresses as ((n*C + c)*IH + h)*IW + w, i.e. they
    # assume dense NCHW. Densify channels_last or sliced inputs first.
    src = input.contiguous()

    # allocate output
    output = torch.empty((N, C, OH, OW), device=src.device, dtype=src.dtype)
    grid = lambda META: (
        triton.cdiv(OW, META["BLOCK_X"]),
        triton.cdiv(OH, META["BLOCK_Y"]),
    )
    # The fixed-support fast path is only valid when both axes are upsampled;
    # any downsampled axis needs the stretched filter of the general kernel.
    kernel = (
        general_interpolate_bicubic2d_aa_kernel
        if (reciprocal_scale_w >= 1.0) or (reciprocal_scale_h >= 1.0)
        else upsample_bicubic2d_aa_kernel
    )
    with torch_device_fn.device(src.device):
        kernel[grid](
            output,
            src,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
        )
    return output