Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_nearest2d.py: 0%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
# Module logger nested under the package-wide "flag_gems" logger. lstrip(".")
# strips any leading dots from __name__ (a no-op for normal module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebind the imported `device` object to its backend device-name string;
# the wrapper below compares input.device.type against this string.
device = device.name
def heur_block_size(args):
    """Pick a power-of-two block size that spreads the output over 12 clusters."""
    total_elems = args["N"] * args["C"] * args["OH"] * args["OW"]
    per_cluster = triton.cdiv(total_elems, 12)  # 12 == cluster_num on this backend
    return triton.next_power_of_2(per_cluster)
# @triton.autotune(
#     configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
# )
@triton.heuristics(
    {
        # When an output dim equals the input dim, the index map is the identity
        # and the per-element scale math can be skipped at compile time.
        "SAME_H": lambda args: args["OH"] == args["IH"],
        "SAME_W": lambda args: args["OW"] == args["IW"],
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,  # output buffer, NCHW of shape (N, C, OH, OW)
    ptr_i,  # input buffer, NCHW of shape (N, C, IH, IW)
    N: tl.constexpr,
    C: tl.constexpr,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,  # 1/scale_h: output row -> input row multiplier
    reciprocal_scale_w,  # 1/scale_w: output col -> input col multiplier
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    # Nearest-neighbor 2D upsample: each lane handles one flat output element.
    # NOTE(review): assumes both tensors are contiguous NCHW — offsets are
    # derived from shapes alone; confirm callers never pass strided views.
    pid = tle.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Decompose the flat index into (n, c, oh, ow) output coordinates.
    # There is deliberately no bounds mask on the load/store below: every
    # component is reduced modulo its extent, so lanes past the last element
    # wrap to in-range coordinates and merely rewrite valid output elements
    # with the same value they already receive — never out of bounds.
    ow = idx % OW
    oh = idx // OW % OH
    c = idx // OW // OH % C
    n = idx // OW // OH // C % N
    if SAME_H:
        ih = oh  # identity map when output height equals input height
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        # (coordinates are non-negative, so truncation == floor); clamp to IH-1
        # so rounding at the top edge cannot index past the input.
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow  # identity map when output width equals input width
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Flat offsets into the contiguous NCHW output / input buffers.
    offset_o = ((n * C + c) * OH + oh) * OW + ow
    offset_i = ((n * C + c) * IH + ih) * IW + iw
    data = tl.load(ptr_i + offset_i)
    tl.store(ptr_o + offset_o, data)
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor upsampling of an NCHW tensor to ``output_size``.

    Args:
        input: 4D tensor laid out as (N, C, IH, IW).
        output_size: Target spatial size (OH, OW).
            (Fixed annotation: ``Tuple[int]`` declared a 1-tuple, but exactly
            two elements are asserted and unpacked below.)
        scales_h: Optional explicit height scale; the kernel maps output row
            ``oh`` to input row ``trunc(oh / scales_h)``.
        scales_w: Optional explicit width scale, analogous to ``scales_h``.

    Returns:
        New tensor of shape (N, C, OH, OW) with input's dtype and device.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # The kernel multiplies by the reciprocal scale, so precompute 1/scale.
    # Without an explicit scale, fall back to the size ratio IH/OH (IW/OW).
    if scales_h is not None:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    # BLOCK_SIZE is supplied by the kernel's @triton.heuristics, so the grid
    # must be a callable evaluated against the resolved META.
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output, input, N, C, OH, OW, IH, IW, reciprocal_scale_h, reciprocal_scale_w
        )
    return output