Coverage for src/flag_gems/ops/upsample_nearest2d.py: 52%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
11device = device.name
12logger = logging.getLogger(__name__)
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """Copy nearest-neighbour source pixels from ``ptr_i`` into ``ptr_o``.

    Both buffers are addressed as dense ``(N * C, H, W)`` arrays: the element
    offset of plane ``p``, row ``h``, column ``w`` is ``(p * H + h) * W + w``.

    Grid axis 0 tiles the ``OH * OW`` output pixels of a plane in chunks of
    ``BLOCK_SIZE``. Grid axis 1 selects the first plane a program handles; the
    program then advances by the axis-1 program count until all ``N * C``
    planes are done.

    Args:
        ptr_o: Pointer to the output buffer.
        ptr_i: Pointer to the input buffer.
        N, C: Batch and channel counts; only their product is used.
        OH, OW: Output height and width.
        IH, IW: Input height and width.
        reciprocal_scale_h, reciprocal_scale_w: Factors that map an output
            row/column index to a (fractional) input row/column index.
        BLOCK_SIZE: Number of output pixels handled per program on axis 0.
        SAME_H, SAME_W: When true, the row/column index is reused unchanged
            and the corresponding scale factor is ignored.
        USE_INT32_IDX: When false, all index arithmetic is carried out in
            64-bit integers.
    """
    # Widen every index source in the 64-bit path. Previously `pid` was
    # re-read as a plain 32-bit program id right after this branch, which
    # silently discarded the int64 cast, and the plane stride products below
    # were always evaluated in 32 bits.
    if USE_INT32_IDX:
        pid = tl.program_id(axis=0)
        nc_iter = tl.program_id(axis=1)
        nc_stride = tl.num_programs(axis=1)
    else:
        pid = tl.program_id(axis=0).to(tl.int64)
        nc_iter = tl.program_id(axis=1).to(tl.int64)
        nc_stride = tl.num_programs(axis=1).to(tl.int64)
    NC = N * C

    # No mask is used: lanes whose flat index runs past OH * OW wrap back onto
    # valid rows through `% OH`, so they only re-store values that another
    # lane stores as well.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    ow = idx % OW
    oh = idx // OW % OH

    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    offset_o = (nc_iter * OH + oh) * OW + ow
    offset_i = (nc_iter * IH + ih) * IW + iw
    # Distance, in elements, between two planes visited by the same program.
    src_index_stride = nc_stride * IH * IW
    dst_index_stride = nc_stride * OH * OW
    while nc_iter < NC:
        data = tl.load(ptr_i + offset_i)
        tl.store(ptr_o + offset_o, data)
        # The per-plane offsets stay fixed; the base pointers move instead.
        ptr_i += src_index_stride
        ptr_o += dst_index_stride
        nc_iter += nc_stride
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 4-D tensor to ``output_size`` with nearest-neighbour lookup.

    Args:
        input: Tensor with four dimensions, unpacked as ``(N, C, IH, IW)``.
            It must live on the device type this module was configured for.
        output_size: Target ``(OH, OW)``.
        scales_h: Optional height scale. When given, ``1 / scales_h`` maps
            output rows to input rows; otherwise ``IH / OH`` is used.
        scales_w: Optional width scale, handled like ``scales_h``.

    Returns:
        A newly allocated tensor of shape ``(N, C, OH, OW)`` with the dtype
        and device of ``input``.

    Raises:
        AssertionError: If the device type, the number of input dimensions or
            the length of ``output_size`` is not as expected.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape

    # The kernel computes offsets as (plane * H + h) * W + w, i.e. it assumes
    # a dense NCHW layout. Strided or channels_last inputs would otherwise be
    # read at the wrong addresses. This is a no-op for contiguous tensors.
    input = input.contiguous()

    if scales_h is not None:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    if output.numel() == 0:
        # Nothing to write; skip launching with a zero-sized grid.
        return output

    total_threads = OH * OW
    # Roughly four planes per program on axis 1. The count is capped because
    # CUDA limits grid dimension y to 65535; the kernel strides over planes by
    # the axis-1 program count, so a smaller grid still covers every plane.
    nc_programs = min(triton.cdiv(N * C, 4), 65535)

    def grid(meta):
        return (triton.cdiv(total_threads, meta["BLOCK_SIZE"]), nc_programs)

    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output,
            input,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
        )
    return output