Coverage for src/flag_gems/runtime/backend/_cambricon/ops/upsample_nearest2d.py: 0%
108 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM
# Child logger under the package-wide "flag_gems" logger.
# NOTE(review): str.lstrip(".") strips leading dot *characters*, not a module
# prefix — harmless here because module paths do not start with ".".
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebind `device` from the runtime device object to its backend name string
# (intentionally shadows the `device` module imported above); used by the
# device-type assertion in upsample_nearest2d below.
device = device.name
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,  # output buffer, contiguous NCHW layout (N, C, OH, OW)
    ptr_i,  # input buffer, contiguous NCHW layout (N, C, IH, IW)
    N,
    C,
    OH,  # output height
    OW,  # output width
    IH,  # input height
    IW,  # input width
    reciprocal_scale_h,  # IH / OH, or 1 / scales_h when the caller supplied one
    reciprocal_scale_w,  # IW / OW, or 1 / scales_w when the caller supplied one
    BLOCK_SIZE: tl.constexpr,  # output elements handled per program
    SAME_H: tl.constexpr,  # height unchanged -> identity row mapping
    SAME_W: tl.constexpr,  # width unchanged -> identity column mapping
):
    # Generic nearest-neighbor 2D upsample: each program copies BLOCK_SIZE
    # output elements, mapping output (oh, ow) to input
    # (floor(oh * 1/scale_h), floor(ow * 1/scale_w)) clamped to the input.
    # Flatten the 2D launch grid into one linear program id (grid y is used
    # by the wrapper to stay under MAX_GRID_SIZE_X).
    pid = tl.program_id(axis=0) + tl.program_id(axis=1) * tl.num_programs(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Decompose the linear output index into (n, c, oh, ow). The modulo at
    # each step makes out-of-range idx (from the last, partially filled
    # program) wrap back to valid coordinates, so loads/stores below need no
    # mask — overrun lanes just redundantly rewrite correct values.
    ow = idx % OW
    oh = idx // OW % OH
    c = idx // OW // OH % C
    n = idx // OW // OH // C % N
    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        # (equivalent here since oh * reciprocal_scale_h is non-negative).
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Linear offsets into the contiguous NCHW buffers.
    offset_o = ((n * C + c) * OH + oh) * OW + ow
    offset_i = ((n * C + c) * IH + ih) * IW + iw
    data = tl.load(ptr_i + offset_i)
    tl.store(ptr_o + offset_o, data)
def configs2():
    """Build the autotuning search space for the specialized 2x kernels.

    Sweeps the H-tile size (BLOCK_H) and the pipeline depth (num_stages);
    num_warps is pinned to 1.
    """
    configs = []
    for stages in (1, 3):
        for tile_h in (8, 16, 32, 64, 128, 256):
            configs.append(
                triton.Config({"BLOCK_H": tile_h}, num_warps=1, num_stages=stages)
            )
    return configs
@triton.autotune(configs=configs2(), key=["N", "C", "OH", "OW"])
@triton.jit
def upsample_nearest2d_kernel_opt(
    ptr_o,  # output buffer, contiguous (N, C, OH, OW)
    ptr_i,  # input buffer, contiguous (N, C, IH, IW)
    N,
    C,
    OH,
    OW: tl.constexpr,  # must equal 2 * IW — kernel is specialized for exact 2x
    IH,
    IW: tl.constexpr,
    BLOCK_H: tl.constexpr,  # output rows per tile (even; from configs2())
):
    # Specialized exact-2x nearest upsample (OH == 2*IH, OW == 2*IW): each
    # input pixel is replicated into a 2x2 output block. Work is partitioned
    # by (n, c) slice across programs; the wrapper picks this variant when
    # N*C is large enough to keep all cores busy.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)

    # Contiguous range [nc_begin, nc_end) of (n, c) slices owned by this
    # program (ceil-divide so every slice is covered).
    nc_nums_per_job = (N * C + num_jobs - 1) // num_jobs
    nc_begin = pid * nc_nums_per_job
    nc_end = min(nc_begin + nc_nums_per_job, N * C)

    loop_num_h = (OH + BLOCK_H - 1) // BLOCK_H  # H tiles per (n, c) slice
    for idx in range((nc_end - nc_begin) * loop_num_h):
        nc_idx = nc_begin + (idx // loop_num_h)
        h_idx = (idx % loop_num_h) * BLOCK_H  # first output row of this tile

        init_out = nc_idx * OH * OW  # base offset of this (n, c) plane
        init_in = nc_idx * IH * IW

        # BLOCK_H output rows are sourced from BLOCK_H // 2 input rows.
        ih = h_idx // 2 + tl.arange(0, BLOCK_H // 2)
        iw = tl.arange(0, IW)
        offset_i = init_in + ih[:, None] * IW + iw

        oh = h_idx + tl.arange(0, BLOCK_H)
        ow = tl.arange(0, OW)
        offset_o = init_out + oh[:, None] * OW + ow

        data = tl.load(ptr_i + offset_i, mask=(ih[:, None] < IH))

        # Duplicate each input pixel along width (IW == OW // 2)...
        tmp = (
            data.reshape(BLOCK_H // 2, OW // 2, 1)
            .broadcast_to(BLOCK_H // 2, OW // 2, 2)
            .reshape(BLOCK_H // 2, 1, OW)
        )
        # ...then duplicate each widened row along height.
        tmp1 = tmp.broadcast_to(BLOCK_H // 2, 2, OW).reshape(BLOCK_H, OW)

        tl.store(ptr_o + offset_o, tmp1, mask=(oh[:, None] < OH))
@triton.autotune(configs=configs2(), key=["N", "C", "OH", "OW"])
@triton.jit
def upsample_nearest2d_kernel_opt_tile_h(
    ptr_o,  # output buffer, contiguous (N, C, OH, OW)
    ptr_i,  # input buffer, contiguous (N, C, IH, IW)
    N,
    C,
    OH,
    OW: tl.constexpr,  # must equal 2 * IW — kernel is specialized for exact 2x
    IH,
    IW: tl.constexpr,
    BLOCK_H: tl.constexpr,  # output rows per tile (even; from configs2())
):
    # Specialized exact-2x nearest upsample, partitioned by H tiles instead
    # of (n, c) slices: program `pid` handles output rows
    # pid*BLOCK_H, pid*BLOCK_H + step, ... (grid-strided over H) for every
    # (n, c) plane. The wrapper picks this variant when N*C is small.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)

    start = pid * BLOCK_H
    step = BLOCK_H * num_jobs
    loop_num_h = (OH - start + step - 1) // step  # H tiles owned per plane

    for idx in range(N * C * loop_num_h):
        nc_idx = idx // loop_num_h
        h_idx = (idx % loop_num_h) * step + start  # first output row of tile

        init_out = nc_idx * OH * OW  # base offset of this (n, c) plane
        init_in = nc_idx * IH * IW

        # BLOCK_H output rows are sourced from BLOCK_H // 2 input rows.
        ih = h_idx // 2 + tl.arange(0, BLOCK_H // 2)
        iw = tl.arange(0, IW)
        offset_i = init_in + ih[:, None] * IW + iw

        oh = h_idx + tl.arange(0, BLOCK_H)
        ow = tl.arange(0, OW)
        offset_o = init_out + oh[:, None] * OW + ow

        data = tl.load(ptr_i + offset_i, mask=(ih[:, None] < IH))

        # Duplicate each input pixel along width (IW == OW // 2)...
        tmp = (
            data.reshape(BLOCK_H // 2, OW // 2, 1)
            .broadcast_to(BLOCK_H // 2, OW // 2, 2)
            .reshape(BLOCK_H // 2, 1, OW)
        )
        # ...then duplicate each widened row along height.
        tmp1 = tmp.broadcast_to(BLOCK_H // 2, 2, OW).reshape(BLOCK_H, OW)

        tl.store(ptr_o + offset_o, tmp1, mask=(oh[:, None] < OH))
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor 2D upsampling for contiguous NCHW tensors.

    Args:
        input: 4D tensor of shape (N, C, IH, IW) on this backend's device.
        output_size: target spatial size (OH, OW).
        scales_h: optional explicit height scale; when omitted, IH / OH is
            used as the reciprocal scale.
        scales_w: optional explicit width scale; when omitted, IW / OW is used.

    Returns:
        A newly allocated tensor of shape (N, C, OH, OW) with input's dtype.
    """
    logger.debug("GEMS_CAMBRICON UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape

    # The kernels index the input as floor(o * reciprocal_scale), so only the
    # reciprocal of the user-facing scale is ever needed.
    reciprocal_scale_h = 1 / scales_h if scales_h is not None else IH / OH
    reciprocal_scale_w = 1 / scales_w if scales_w is not None else IW / OW

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)

    with torch_device_fn.device(input.device):
        exact_2x = (
            reciprocal_scale_h == 0.5
            and reciprocal_scale_w == 0.5
            and IH / OH == 0.5
            and IW / OW == 0.5
        )
        if exact_2x:
            # Specialized 2x kernels: partition by (n, c) slice when there are
            # enough slices to occupy the cores, otherwise tile along H.
            fast_kernel = (
                upsample_nearest2d_kernel_opt
                if N * C > 48
                else upsample_nearest2d_kernel_opt_tile_h
            )
            fast_kernel[TOTAL_CORE_NUM,](output, input, N, C, OH, OW, IH, IW)
        else:
            total_threads = N * C * OH * OW

            # incase grid check error: cap grid x at the hardware limit and
            # spill the remainder into grid y.
            def grid_fn(META):
                num_blocks = triton.cdiv(total_threads, META["BLOCK_SIZE"])
                grid_x = min(num_blocks, MAX_GRID_SIZE_X)
                return (grid_x, triton.cdiv(num_blocks, grid_x))

            upsample_nearest2d_kernel[grid_fn](
                output,
                input,
                N,
                C,
                OH,
                OW,
                IH,
                IW,
                reciprocal_scale_h,
                reciprocal_scale_w,
            )

    return output