Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_bicubic2d_aa.py: 0%
200 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
13device = device.name
def configs():
    """Build the candidate Triton launch configurations.

    Returns:
        A list with one ``triton.Config`` per combination of a
        ``(BLOCK_X, BLOCK_Y)`` tile shape and a warp count.  Tile shapes
        vary slowest (``BLOCK_X`` descending from 512 to 64, each with
        ``BLOCK_Y`` 2 then 1); warp counts 4 then 8 vary fastest.
    """
    tile_shapes = []
    for block_x in (512, 256, 128, 64):
        for block_y in (2, 1):
            tile_shapes.append((block_x, block_y))
    warp_counts = [4, 8]

    candidates = []
    for block_x, block_y in tile_shapes:
        for num_warps in warp_counts:
            meta = {"BLOCK_X": block_x, "BLOCK_Y": block_y}
            candidates.append(triton.Config(meta, num_warps=num_warps))
    return candidates
def heur_m_block_size(args):
    """Heuristic for ``BLOCK_X`` derived from the output width.

    Args:
        args: Mapping of kernel argument names to values; only ``"OW"``
            (output width) is read.

    Returns:
        ``ceil(OW / 12)`` rounded up to the next power of two.
    """
    # The original source labels this constant "cluster_num"; presumably the
    # number of XPU clusters the width is spread over -- TODO confirm.
    cluster_num = 12
    columns_per_cluster = triton.cdiv(args["OW"], cluster_num)
    return triton.next_power_of_2(columns_per_cluster)
def heur_n_block_size(args):
    """Heuristic for ``BLOCK_Y``: every program handles one output row.

    Args:
        args: Mapping of kernel argument names to values.  Accepted for
            signature compatibility with the other heuristic; not read.

    Returns:
        Always ``1``.
    """
    # An OH-based alternative, min(next_power_of_2(OH), 8192), used to sit
    # after this return and could never execute; it has been removed.
    return 1
43# @triton.autotune(
44# configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
45# key=["N", "C", "OH", "OW"],
46# )
@triton.heuristics(
    values={
        "BLOCK_X": heur_m_block_size,
        "BLOCK_Y": heur_n_block_size,
    },
)
@triton.jit
def upsample_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N: tl.constexpr,
    C: tl.constexpr,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    # Bicubic resampling with a fixed filter support of 2.0 and no filter
    # stretching (invscale == 1.0).  The host wrapper in this file selects this
    # kernel only when both reciprocal scales are below 1.0.
    #
    # ptr_o / ptr_i       : output / input base pointers.  The index arithmetic
    #                       below treats them as densely packed (N, C, OH, OW)
    #                       and (N, C, IH, IW) row-major arrays.
    # N, C                : leading dimensions, compile-time constants.
    # OH, OW, IH, IW      : output and input spatial sizes.
    # reciprocal_scale_*  : input-coordinate step per output pixel on each axis.
    # BLOCK_X, BLOCK_Y    : tile size (columns, rows) handled by one program.
    #
    # Each program computes one BLOCK_Y x BLOCK_X tile of output pixels and
    # repeats it for every (n, c) plane, reusing the same filter weights.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Indices are wrapped with a modulo instead of being masked: lanes past the
    # edge recompute an in-range pixel, and the final store is unmasked.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    support_w = 2.0
    support_h = 2.0

    # _compute_weights_span
    # Centre of each output pixel expressed in input coordinates.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    # First input index of the filter window, clamped at 0.
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    # Number of taps in the window after clamping its end to the input size.
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h
    invscale_w = 1.0
    invscale_h = 1.0
    # Cubic convolution coefficient.  For tap k the distance from the centre is
    # d = |k + start - center + 0.5| and the weight is
    #   ((a + 2) d - (a + 3)) d^2 + 1        if d < 1
    #   a (((d - 5) d + 8) d - 4)            if 1 <= d < 2
    #   0                                    otherwise,
    # forced to 0 for taps with k >= span_size.
    # Five taps are unrolled per axis: int(support + 0.5) * 2 + 1 == 5 for a
    # support of 2.0, the same tap-count formula the general kernel evaluates.
    a = -0.5
    wy0 = tl.abs((0 + start_minus_center_h + 0.5) * invscale_h)
    weight_y0 = tl.where(
        0 < span_size_h,
        tl.where(
            wy0 < 1.0,
            ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
            tl.where(wy0 < 2.0, (((wy0 - 5) * wy0 + 8) * wy0 - 4) * a, 0),
        ),
        0,
    )
    wy1 = tl.abs((1 + start_minus_center_h + 0.5) * invscale_h)
    weight_y1 = tl.where(
        1 < span_size_h,
        tl.where(
            wy1 < 1.0,
            ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
            tl.where(wy1 < 2.0, (((wy1 - 5) * wy1 + 8) * wy1 - 4) * a, 0),
        ),
        0,
    )
    wy2 = tl.abs((2 + start_minus_center_h + 0.5) * invscale_h)
    weight_y2 = tl.where(
        2 < span_size_h,
        tl.where(
            wy2 < 1.0,
            ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
            tl.where(wy2 < 2.0, (((wy2 - 5) * wy2 + 8) * wy2 - 4) * a, 0),
        ),
        0,
    )
    wy3 = tl.abs((3 + start_minus_center_h + 0.5) * invscale_h)
    weight_y3 = tl.where(
        3 < span_size_h,
        tl.where(
            wy3 < 1.0,
            ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
            tl.where(wy3 < 2.0, (((wy3 - 5) * wy3 + 8) * wy3 - 4) * a, 0),
        ),
        0,
    )
    wy4 = tl.abs((4 + start_minus_center_h + 0.5) * invscale_h)
    weight_y4 = tl.where(
        4 < span_size_h,
        tl.where(
            wy4 < 1.0,
            ((a + 2) * wy4 - (a + 3)) * wy4 * wy4 + 1,
            tl.where(wy4 < 2.0, (((wy4 - 5) * wy4 + 8) * wy4 - 4) * a, 0),
        ),
        0,
    )
    # Normalise the vertical weights so they sum to 1; a zero sum is replaced
    # by 1 so the division below cannot divide by zero.
    weight_y_total = weight_y0 + weight_y1 + weight_y2 + weight_y3 + weight_y4
    weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
    weight_y0 /= weight_y_total
    weight_y1 /= weight_y_total
    weight_y2 /= weight_y_total
    weight_y3 /= weight_y_total
    weight_y4 /= weight_y_total

    # Same five-tap weight computation along the horizontal axis.
    wx0 = tl.abs((0 + start_minus_center_w + 0.5) * invscale_w)
    weight_x0 = tl.where(
        0 < span_size_w,
        tl.where(
            wx0 < 1.0,
            ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
            tl.where(wx0 < 2.0, (((wx0 - 5) * wx0 + 8) * wx0 - 4) * a, 0),
        ),
        0,
    )
    wx1 = tl.abs((1 + start_minus_center_w + 0.5) * invscale_w)
    weight_x1 = tl.where(
        1 < span_size_w,
        tl.where(
            wx1 < 1.0,
            ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
            tl.where(wx1 < 2.0, (((wx1 - 5) * wx1 + 8) * wx1 - 4) * a, 0),
        ),
        0,
    )
    wx2 = tl.abs((2 + start_minus_center_w + 0.5) * invscale_w)
    weight_x2 = tl.where(
        2 < span_size_w,
        tl.where(
            wx2 < 1.0,
            ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
            tl.where(wx2 < 2.0, (((wx2 - 5) * wx2 + 8) * wx2 - 4) * a, 0),
        ),
        0,
    )
    wx3 = tl.abs((3 + start_minus_center_w + 0.5) * invscale_w)
    weight_x3 = tl.where(
        3 < span_size_w,
        tl.where(
            wx3 < 1.0,
            ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
            tl.where(wx3 < 2.0, (((wx3 - 5) * wx3 + 8) * wx3 - 4) * a, 0),
        ),
        0,
    )
    wx4 = tl.abs((4 + start_minus_center_w + 0.5) * invscale_w)
    weight_x4 = tl.where(
        4 < span_size_w,
        tl.where(
            wx4 < 1.0,
            ((a + 2) * wx4 - (a + 3)) * wx4 * wx4 + 1,
            tl.where(wx4 < 2.0, (((wx4 - 5) * wx4 + 8) * wx4 - 4) * a, 0),
        ),
        0,
    )
    # Normalise the horizontal weights, again guarding against a zero sum.
    weight_x_total = weight_x0 + weight_x1 + weight_x2 + weight_x3 + weight_x4
    weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
    weight_x0 /= weight_x_total
    weight_x1 /= weight_x_total
    weight_x2 /= weight_x_total
    weight_x3 /= weight_x_total
    weight_x4 /= weight_x_total

    # Load masks: tap k is read only when span_start + k is still inside the
    # input on that axis; masked-off lanes load 0.
    mask_y0 = span_start_h[:, None] + 0 < IH
    mask_y1 = span_start_h[:, None] + 1 < IH
    mask_y2 = span_start_h[:, None] + 2 < IH
    mask_y3 = span_start_h[:, None] + 3 < IH
    mask_y4 = span_start_h[:, None] + 4 < IH
    mask_x0 = span_start_w[None, :] + 0 < IW
    mask_x1 = span_start_w[None, :] + 1 < IW
    mask_x2 = span_start_w[None, :] + 2 < IW
    mask_x3 = span_start_w[None, :] + 3 < IW
    mask_x4 = span_start_w[None, :] + 4 < IW

    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Flat offset of the window's top-left tap in plane (n, c).
            offset_base = (
                (n * C + c) * IH + span_start_h[:, None]
            ) * IW + span_start_w[None, :]

            # dataRC is the input sample at window row R, window column C.
            data00 = tl.load(
                ptr_i + (offset_base + 0 * IW + 0),
                mask=(mask_y0 & mask_x0),
                other=0,
            )
            data01 = tl.load(
                ptr_i + (offset_base + 0 * IW + 1),
                mask=(mask_y0 & mask_x1),
                other=0,
            )
            data02 = tl.load(
                ptr_i + (offset_base + 0 * IW + 2),
                mask=(mask_y0 & mask_x2),
                other=0,
            )
            data03 = tl.load(
                ptr_i + (offset_base + 0 * IW + 3),
                mask=(mask_y0 & mask_x3),
                other=0,
            )
            data04 = tl.load(
                ptr_i + (offset_base + 0 * IW + 4),
                mask=(mask_y0 & mask_x4),
                other=0,
            )

            data10 = tl.load(
                ptr_i + (offset_base + 1 * IW + 0),
                mask=(mask_y1 & mask_x0),
                other=0,
            )
            data11 = tl.load(
                ptr_i + (offset_base + 1 * IW + 1),
                mask=(mask_y1 & mask_x1),
                other=0,
            )
            data12 = tl.load(
                ptr_i + (offset_base + 1 * IW + 2),
                mask=(mask_y1 & mask_x2),
                other=0,
            )
            data13 = tl.load(
                ptr_i + (offset_base + 1 * IW + 3),
                mask=(mask_y1 & mask_x3),
                other=0,
            )
            data14 = tl.load(
                ptr_i + (offset_base + 1 * IW + 4),
                mask=(mask_y1 & mask_x4),
                other=0,
            )

            data20 = tl.load(
                ptr_i + (offset_base + 2 * IW + 0),
                mask=(mask_y2 & mask_x0),
                other=0,
            )
            data21 = tl.load(
                ptr_i + (offset_base + 2 * IW + 1),
                mask=(mask_y2 & mask_x1),
                other=0,
            )
            data22 = tl.load(
                ptr_i + (offset_base + 2 * IW + 2),
                mask=(mask_y2 & mask_x2),
                other=0,
            )
            data23 = tl.load(
                ptr_i + (offset_base + 2 * IW + 3),
                mask=(mask_y2 & mask_x3),
                other=0,
            )
            data24 = tl.load(
                ptr_i + (offset_base + 2 * IW + 4),
                mask=(mask_y2 & mask_x4),
                other=0,
            )

            data30 = tl.load(
                ptr_i + (offset_base + 3 * IW + 0),
                mask=(mask_y3 & mask_x0),
                other=0,
            )
            data31 = tl.load(
                ptr_i + (offset_base + 3 * IW + 1),
                mask=(mask_y3 & mask_x1),
                other=0,
            )
            data32 = tl.load(
                ptr_i + (offset_base + 3 * IW + 2),
                mask=(mask_y3 & mask_x2),
                other=0,
            )
            data33 = tl.load(
                ptr_i + (offset_base + 3 * IW + 3),
                mask=(mask_y3 & mask_x3),
                other=0,
            )
            data34 = tl.load(
                ptr_i + (offset_base + 3 * IW + 4),
                mask=(mask_y3 & mask_x4),
                other=0,
            )

            data40 = tl.load(
                ptr_i + (offset_base + 4 * IW + 0),
                mask=(mask_y4 & mask_x0),
                other=0,
            )
            data41 = tl.load(
                ptr_i + (offset_base + 4 * IW + 1),
                mask=(mask_y4 & mask_x1),
                other=0,
            )
            data42 = tl.load(
                ptr_i + (offset_base + 4 * IW + 2),
                mask=(mask_y4 & mask_x2),
                other=0,
            )
            data43 = tl.load(
                ptr_i + (offset_base + 4 * IW + 3),
                mask=(mask_y4 & mask_x3),
                other=0,
            )
            data44 = tl.load(
                ptr_i + (offset_base + 4 * IW + 4),
                mask=(mask_y4 & mask_x4),
                other=0,
            )

            # Horizontal pass: collapse each window row with the x weights.
            data0 = (
                data00 * weight_x0[None, :]
                + data01 * weight_x1[None, :]
                + data02 * weight_x2[None, :]
                + data03 * weight_x3[None, :]
                + data04 * weight_x4[None, :]
            )
            data1 = (
                data10 * weight_x0[None, :]
                + data11 * weight_x1[None, :]
                + data12 * weight_x2[None, :]
                + data13 * weight_x3[None, :]
                + data14 * weight_x4[None, :]
            )
            data2 = (
                data20 * weight_x0[None, :]
                + data21 * weight_x1[None, :]
                + data22 * weight_x2[None, :]
                + data23 * weight_x3[None, :]
                + data24 * weight_x4[None, :]
            )
            data3 = (
                data30 * weight_x0[None, :]
                + data31 * weight_x1[None, :]
                + data32 * weight_x2[None, :]
                + data33 * weight_x3[None, :]
                + data34 * weight_x4[None, :]
            )
            data4 = (
                data40 * weight_x0[None, :]
                + data41 * weight_x1[None, :]
                + data42 * weight_x2[None, :]
                + data43 * weight_x3[None, :]
                + data44 * weight_x4[None, :]
            )
            # Vertical pass: combine the five row results with the y weights.
            result = (
                data0 * weight_y0[:, None]
                + data1 * weight_y1[:, None]
                + data2 * weight_y2[:, None]
                + data3 * weight_y3[:, None]
                + data4 * weight_y4[:, None]
            )

            # Unmasked store: oh / ow were wrapped into range above.
            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)
402# upsample and downsample
403# @triton.autotune(
404# configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
405# key=["N", "C", "OH", "OW"],
406# )
@triton.heuristics(
    values={
        "BLOCK_X": heur_m_block_size,
        "BLOCK_Y": heur_n_block_size,
    },
)
@triton.jit
def general_interpolate_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    # Bicubic resampling whose filter support widens with the reciprocal scale
    # on any axis where that scale is >= 1.0.  The host wrapper in this file
    # selects this kernel when at least one reciprocal scale is >= 1.0.
    #
    # ptr_o / ptr_i       : output / input base pointers.  The index arithmetic
    #                       below treats them as densely packed (N, C, OH, OW)
    #                       and (N, C, IH, IW) row-major arrays.
    # N, C                : leading dimensions (runtime values here).
    # OH, OW, IH, IW      : output and input spatial sizes.
    # reciprocal_scale_*  : input-coordinate step per output pixel on each axis.
    # BLOCK_X, BLOCK_Y    : tile size (columns, rows) handled by one program.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Indices are wrapped with a modulo instead of being masked: lanes past the
    # edge recompute an in-range pixel, and the final store is unmasked.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    # Filter half-width in input pixels: 2 * reciprocal scale when that scale
    # is >= 1.0, otherwise the plain bicubic support of 2.
    if reciprocal_scale_w >= 1.0:
        support_w = 2 * reciprocal_scale_w
    else:
        support_w = 2.0
    if reciprocal_scale_h >= 1.0:
        support_h = 2 * reciprocal_scale_h
    else:
        support_h = 2.0

    # Tap-loop trip count per axis; taps at or beyond span_size get weight 0.
    interpolate_w = (support_w + 0.5).to(tl.int32) * 2 + 1
    interpolate_h = (support_h + 0.5).to(tl.int32) * 2 + 1

    # _compute_weights_span
    # Centre of each output pixel expressed in input coordinates.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    # First input index of the filter window, clamped at 0.
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    # Number of taps in the window after clamping its end to the input size.
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )

    # Factor that rescales tap distances back into the cubic's [0, 2) domain
    # when the support was widened above.
    if reciprocal_scale_w >= 1.0:
        invscale_w = 1.0 / reciprocal_scale_w
    else:
        invscale_w = 1.0
    if reciprocal_scale_h >= 1.0:
        invscale_h = 1.0 / reciprocal_scale_h
    else:
        invscale_h = 1.0
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h

    # Cubic convolution coefficient.  With d the scaled tap distance:
    #   ((a + 2) d - (a + 3)) d^2 + 1   if d < 1
    #   a (((d - 5) d + 8) d - 4)       if 1 <= d < 2
    #   0                               otherwise.
    a = -0.5
    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Flat offset of the window's top-left tap in plane (n, c).
            # NOTE(review): span_start_w is added without [None, :] here,
            # unlike the masks and the output offset below; this relies on
            # implicit broadcasting of a 1-D block against a (BLOCK_Y, 1) block.
            offset_base = ((n * C + c) * IH + span_start_h[:, None]) * IW + span_start_w
            weight_y_total = tl.zeros((BLOCK_Y,), dtype=tl.float32)
            result = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
            for y in range(0, interpolate_h, 1):
                wy = tl.abs((y + start_minus_center_h + 0.5) * invscale_h)
                weight_y = tl.where(
                    y < span_size_h,
                    tl.where(
                        wy < 1.0,
                        ((a + 2) * wy - (a + 3)) * wy * wy + 1,
                        tl.where(wy < 2.0, (((wy - 5) * wy + 8) * wy - 4) * a, 0),
                    ),
                    0,
                )
                weight_y_total += weight_y
                # NOTE(review): the x weights and their sum depend only on x,
                # yet are recomputed for every y, c and n iteration.
                weight_x_total = tl.zeros((BLOCK_X,), dtype=tl.float32)
                buffer = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
                for x in range(0, interpolate_w, 1):
                    wx = tl.abs((x + start_minus_center_w + 0.5) * invscale_w)
                    weight_x = tl.where(
                        x < span_size_w,
                        tl.where(
                            wx < 1.0,
                            ((a + 2) * wx - (a + 3)) * wx * wx + 1,
                            tl.where(wx < 2.0, (((wx - 5) * wx + 8) * wx - 4) * a, 0),
                        ),
                        0,
                    )
                    weight_x_total += weight_x
                    # Taps outside the input are masked off and load 0.
                    data = tl.load(
                        ptr_i + (offset_base + y * IW + x),
                        mask=(span_start_h[:, None] + y < IH)
                        & (span_start_w[None, :] + x < IW),
                        other=0,
                    )
                    buffer += data * weight_x[None, :]
                # Normalise this row's horizontal sum (a zero weight sum is
                # replaced by 1), then add it scaled by the un-normalised
                # vertical weight.
                weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
                result += buffer / weight_x_total[None, :] * weight_y[:, None]
            # Vertical normalisation is applied once, after all rows are summed.
            weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
            result /= weight_y_total[:, None]
            # Unmasked store: oh / ow were wrapped into range above.
            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)
def bicubic_reciprocal_scale(src_size, dst_size, align_corners, scale):
    """Return the input-coordinate step per output pixel along one axis.

    Args:
        src_size: Input extent along the axis.
        dst_size: Output extent along the axis.
        align_corners: When true, map the first and last samples of input
            and output onto each other; ``scale`` is then ignored.
        scale: Optional explicit scale factor.  Used (as ``1 / scale``) only
            when ``align_corners`` is false and the value is positive.

    Returns:
        ``(src_size - 1) / (dst_size - 1)`` with ``align_corners`` (``0``
        when ``dst_size`` is not greater than 1); otherwise ``1.0 / scale``
        for a positive ``scale``, else ``src_size / dst_size``.
    """
    if not align_corners:
        has_usable_scale = scale is not None and scale > 0
        return 1.0 / scale if has_usable_scale else src_size / dst_size

    # A single output sample has no span to stretch over.
    if not dst_size > 1:
        return 0
    return (src_size - 1) / (dst_size - 1)
528# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L12547
def _upsample_bicubic2d_aa(
    input: torch.Tensor,
    output_size: Tuple[int],
    align_corners: bool = False,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Resize a 4-D tensor with the bicubic kernels defined in this module.

    Args:
        input: Tensor of shape ``(N, C, IH, IW)`` on this backend's device.
        output_size: Target ``(OH, OW)``.
        align_corners: Forwarded to ``bicubic_reciprocal_scale``.
        scales_h: Optional explicit height scale factor.
        scales_w: Optional explicit width scale factor.

    Returns:
        A new tensor of shape ``(N, C, OH, OW)`` with the dtype and device of
        ``input``.

    Raises:
        AssertionError: If ``input`` is on another device type, is not 4-D,
            or ``output_size`` does not have exactly two entries.
    """
    import os

    logger.debug("GEMS UPSAMPLE BICUBIC2D AA")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    N, C, IH, IW = input.shape

    # Both kernels compute flat offsets as ((n * C + c) * IH + h) * IW + w, so
    # they need a densely packed NCHW buffer.  A no-op for contiguous input.
    input = input.contiguous()

    reciprocal_scale_h = bicubic_reciprocal_scale(IH, OH, align_corners, scales_h)
    reciprocal_scale_w = bicubic_reciprocal_scale(IW, OW, align_corners, scales_w)

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    grid = lambda META: (
        triton.cdiv(OW, META["BLOCK_X"]),
        triton.cdiv(OH, META["BLOCK_Y"]),
    )
    # The general kernel handles any axis whose reciprocal scale is >= 1.0;
    # the unrolled 5x5 kernel is used only when both are below 1.0.
    kernel = (
        general_interpolate_bicubic2d_aa_kernel
        if (reciprocal_scale_w >= 1.0) or (reciprocal_scale_h >= 1.0)
        else upsample_bicubic2d_aa_kernel
    )

    # Backend flags that must be set while the kernel is compiled/launched.
    # Remember the caller's values so they are restored (not just deleted)
    # afterwards, and use try/finally so a failing launch cannot leak them.
    sim_flags = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    previous_values = {name: os.environ.get(name) for name in sim_flags}
    for name in sim_flags:
        os.environ[name] = "1"
    try:
        with torch_device_fn.device(input.device):
            kernel[grid](
                output,
                input,
                N,
                C,
                OH,
                OW,
                IH,
                IW,
                reciprocal_scale_h,
                reciprocal_scale_w,
            )
    finally:
        for name, previous in previous_values.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    return output