Coverage for src/flag_gems/runtime/backend/_metax/ops/upsample_nearest2d.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
# Module logger, named as a child of the "flag_gems" logger hierarchy.
logger = logging.getLogger(f"flag_gems.{__name__}")
# Replace the imported `device` object with its name string; the wrapper
# below compares this string against `input.device.type`.
device = device.name
# Nearest-neighbour 2D upsampling kernel over a dense NCHW layout.
#
# The output is treated as a flat array of N*C*OH*OW elements. Each program
# handles BLOCK_SIZE consecutive flat output indices, decodes each one into
# (n, c, oh, ow), selects the source pixel (ih, iw) and copies it.
#
# The launch site passes only ptr_o .. reciprocal_scale_w, so BLOCK_SIZE comes
# from the autotune configs and SAME_H / SAME_W / USE_INT32_IDX come from the
# heuristic config. NOTE(review): SAME_H / SAME_W presumably mean IH == OH /
# IW == OW — confirm in runtime.get_heuristic_config("upsample_nearest2d").
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,  # output buffer, indexed as dense (N, C, OH, OW)
    ptr_i,  # input buffer, indexed as dense (N, C, IH, IW)
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,  # input rows per output row (wrapper: 1/scales_h or IH/OH)
    reciprocal_scale_w,  # input cols per output col (wrapper: 1/scales_w or IW/OW)
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,  # when true, ih = oh (no height remapping)
    SAME_W: tl.constexpr,  # when true, iw = ow (no width remapping)
    USE_INT32_IDX: tl.constexpr,  # selects tl.program_id vs tle.program_id
):
    if USE_INT32_IDX:
        pid = tl.program_id(axis=0)
    else:
        # NOTE(review): tle.program_id presumably yields a 64-bit id so that
        # pid * BLOCK_SIZE cannot overflow int32 on very large outputs — confirm.
        pid = tle.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Decode the flat output index into (n, c, oh, ow).
    # There is deliberately no load/store mask. In the last program, lanes with
    # idx >= N*C*OH*OW have n wrapped back into range by `% N`, so their
    # offset_o equals idx mod (N*C*OH*OW). Such a lane re-writes an in-range
    # output element with the same value its owning lane writes, rather than
    # storing past the end of the output.
    ow = idx % OW
    oh = idx // OW % OH
    c = idx // OW // OH % C
    n = idx // OW // OH // C % N
    if SAME_H:
        ih = oh
    else:
        # Source row = floor(oh * reciprocal_scale_h), clamped to the last row.
        # tl.floor() cannot be found in Triton 2.3.1, so truncate with an int32
        # cast instead (same as floor for non-negative products).
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        # Source column, computed the same way as the source row above.
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Dense row-major NCHW offsets. No strides are taken, so both tensors must
    # be contiguous in memory.
    offset_o = ((n * C + c) * OH + oh) * OW + ow
    offset_i = ((n * C + c) * IH + ih) * IW + iw
    data = tl.load(ptr_i + offset_i)
    tl.store(ptr_o + offset_o, data)
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbour 2D upsampling of a 4-D (N, C, IH, IW) tensor.

    Args:
        input: 4-D tensor on the backend device. It is made contiguous before
            launch, because the kernel indexes memory as dense NCHW without
            strides.
        output_size: Target spatial size ``(OH, OW)``.
        scales_h: Optional height scale factor. When given, the source row
            step is ``1 / scales_h``; otherwise it is ``IH / OH``.
        scales_w: Optional width scale factor, handled like ``scales_h``.

    Returns:
        A new tensor of shape ``(N, C, OH, OW)`` with ``input``'s dtype and
        device.

    Raises:
        AssertionError: if ``input`` is on the wrong device type, is not 4-D,
            or ``output_size`` does not have exactly two entries.
        ZeroDivisionError: if ``OH`` (or ``OW``) is 0 and the matching scale
            factor is not given.
    """
    logger.debug("METAX GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # The kernel computes dense NCHW offsets. Strided views (channels_last,
    # permuted or sliced tensors) would otherwise be read incorrectly.
    # This is a no-op for tensors that are already contiguous.
    input = input.contiguous()
    # An explicit scale factor takes precedence over the size ratio.
    reciprocal_scale_h = 1 / scales_h if scales_h is not None else IH / OH
    reciprocal_scale_w = 1 / scales_w if scales_w is not None else IW / OW
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    if total_threads == 0:
        # Empty result: skip autotuning and the kernel launch entirely.
        return output
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output, input, N, C, OH, OW, IH, IW, reciprocal_scale_h, reciprocal_scale_w
        )
    return output