Coverage for src/flag_gems/experimental_ops/upsample_nearest3d.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def upsample_nearest3d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    ID,
    IH,
    IW,
    OD,
    OH,
    OW,
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    scale_d,
    scale_h,
    scale_w,
    total_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy nearest-neighbour source elements into a [N, C, OD, OH, OW] output.

    Each program handles ``BLOCK_SIZE`` consecutive flat output indices. A flat
    index is unravelled into (n, c, od, oh, ow); the source index along each
    spatial axis is ``min(int(out_idx * scale), in_size - 1)``.

    Args:
        in_ptr, out_ptr: Base pointers of the input and output tensors.
        N, C: Batch and channel extents (``N`` is not read by the kernel body;
            ``n`` is derived as the remaining quotient).
        ID, IH, IW: Input depth/height/width, used to clamp source indices.
        OD, OH, OW: Output depth/height/width, used to unravel flat indices.
        in_stride_*, out_stride_*: Per-dimension element strides.
        scale_d, scale_h, scale_w: Multipliers mapping an output index to a
            source index along each spatial axis.
        total_elements: Number of output elements; lanes at or beyond this
            index are masked off.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Promote to int64 before multiplying by BLOCK_SIZE: in int32 the flat
    # index (and the stride products below) can wrap for very large tensors.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_elements

    # Unravel offsets into (n, c, od, oh, ow) for an output tensor of shape
    # [N, C, OD, OH, OW].
    ow = offsets % OW
    tmp = offsets // OW
    oh = tmp % OH
    tmp = tmp // OH
    od = tmp % OD
    tmp = tmp // OD
    c = tmp % C
    n = tmp // C

    # Compute nearest input indices: truncate (out_idx * scale) and clamp to
    # the last valid input index.
    od_f = od.to(tl.float32)
    oh_f = oh.to(tl.float32)
    ow_f = ow.to(tl.float32)

    id_src = tl.minimum((od_f * scale_d).to(tl.int32), ID - 1).to(tl.int64)
    ih_src = tl.minimum((oh_f * scale_h).to(tl.int32), IH - 1).to(tl.int64)
    iw_src = tl.minimum((ow_f * scale_w).to(tl.int32), IW - 1).to(tl.int64)

    # Compute input/output offsets using strides (all operands are int64 here).
    in_offset = (
        n * in_stride_n
        + c * in_stride_c
        + id_src * in_stride_d
        + ih_src * in_stride_h
        + iw_src * in_stride_w
    )
    out_offset = (
        n * out_stride_n
        + c * out_stride_c
        + od * out_stride_d
        + oh * out_stride_h
        + ow * out_stride_w
    )

    # Masked lanes may hold out-of-range indices; the mask keeps them from
    # touching memory.
    vals = tl.load(in_ptr + in_offset, mask=mask, other=0)
    tl.store(out_ptr + out_offset, vals, mask=mask)

76 

77 

78def _ensure_5d_input(x: torch.Tensor): 

79 if x.dim() != 5: 

80 raise ValueError( 

81 f"Expected 5D input [N, C, D, H, W], but got shape {tuple(x.shape)}" 

82 ) 

83 return x 

84 

85 

86def _normalize_output_size(output_size): 

87 if output_size is None: 

88 return None 

89 if isinstance(output_size, torch.Size): 

90 output_size = tuple(int(s) for s in output_size) 

91 elif isinstance(output_size, (list, tuple)): 

92 output_size = tuple(int(s) for s in output_size) 

93 else: 

94 raise ValueError("output_size must be a sequence of 3 integers or torch.Size") 

95 if len(output_size) != 3: 

96 raise ValueError("output_size must have length 3: (out_d, out_h, out_w)") 

97 return output_size 

98 

99 

100def _normalize_scale_factors(scales): 

101 if scales is None: 

102 return None 

103 if isinstance(scales, (list, tuple)): 

104 if len(scales) != 3: 

105 raise ValueError( 

106 "scale_factors must have length 3: (scale_d, scale_h, scale_w)" 

107 ) 

108 return tuple(float(s) if s is not None else None for s in scales) 

109 else: 

110 raise ValueError("scale_factors must be a sequence of 3 floats") 

111 

112 

113def _compute_out_size_and_kernel_scales(ID, IH, IW, output_size, scales_tuple): 

114 # Returns (OD, OH, OW, kscale_d, kscale_h, kscale_w) 

115 # kscale_* is the multiplier used as: src_idx = floor(out_idx * kscale_*) 

116 if output_size is not None: 

117 OD, OH, OW = int(output_size[0]), int(output_size[1]), int(output_size[2]) 

118 if OD <= 0 or OH <= 0 or OW <= 0: 

119 raise ValueError("Output sizes must be positive") 

120 # When output_size is given, kscale = input_size / output_size 

121 kscale_d = float(ID) / float(OD) 

122 kscale_h = float(IH) / float(OH) 

123 kscale_w = float(IW) / float(OW) 

124 else: 

125 sd, sh, sw = scales_tuple 

126 if sd is None or sh is None or sw is None: 

127 raise ValueError( 

128 "All scale factors (scale_d, scale_h, scale_w) must be provided when output_size is None" 

129 ) 

130 if sd <= 0.0 or sh <= 0.0 or sw <= 0.0: 

131 raise ValueError("Scale factors must be positive") 

132 OD = int(torch.floor(torch.tensor(ID * sd)).item()) 

133 OH = int(torch.floor(torch.tensor(IH * sh)).item()) 

134 OW = int(torch.floor(torch.tensor(IW * sw)).item()) 

135 if OD <= 0 or OH <= 0 or OW <= 0: 

136 raise ValueError("Computed output sizes must be positive") 

137 # When scale_factors are given, src_idx = floor(out_idx / scale) = floor(out_idx * (1/scale)) 

138 kscale_d = 1.0 / float(sd) 

139 kscale_h = 1.0 / float(sh) 

140 kscale_w = 1.0 / float(sw) 

141 return OD, OH, OW, kscale_d, kscale_h, kscale_w 

142 

143 

144def _launch_upsample_nearest3d( 

145 input: torch.Tensor, 

146 output: torch.Tensor, 

147 kscale_d: float, 

148 kscale_h: float, 

149 kscale_w: float, 

150): 

151 N, C, ID, IH, IW = input.shape 

152 OD, OH, OW = output.shape[2], output.shape[3], output.shape[4] 

153 

154 in_strides = input.stride() 

155 out_strides = output.stride() 

156 

157 total = N * C * OD * OH * OW 

158 if total == 0: 

159 return output 

160 

161 BLOCK_SIZE = 1024 

162 grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),) 

163 

164 upsample_nearest3d_kernel[grid]( 

165 input, 

166 output, 

167 N, 

168 C, 

169 ID, 

170 IH, 

171 IW, 

172 OD, 

173 OH, 

174 OW, 

175 in_strides[0], 

176 in_strides[1], 

177 in_strides[2], 

178 in_strides[3], 

179 in_strides[4], 

180 out_strides[0], 

181 out_strides[1], 

182 out_strides[2], 

183 out_strides[3], 

184 out_strides[4], 

185 float(kscale_d), 

186 float(kscale_h), 

187 float(kscale_w), 

188 total, 

189 BLOCK_SIZE=BLOCK_SIZE, 

190 ) 

191 return output 

192 

193 

def upsample_nearest3d(
    input: torch.Tensor, output_size=None, scales_d=None, scales_h=None, scales_w=None
):
    """Nearest-neighbour upsampling of a 5D tensor into a newly allocated output.

    Args:
        input: Tensor shaped [N, C, D, H, W].
        output_size: Optional (out_d, out_h, out_w). When given, the scale
            arguments are not consulted.
        scales_d, scales_h, scales_w: Per-axis scale factors, used only when
            ``output_size`` is ``None``.

    Returns:
        A new tensor of shape (N, C, OD, OH, OW) with the same dtype, device
        and layout as ``input``.

    Raises:
        ValueError: Propagated from the validation helpers for a non-5D
            input, a malformed ``output_size`` or unusable scale factors.
    """
    src = _ensure_5d_input(input)
    output_size = _normalize_output_size(output_size)

    if output_size is None:
        scale_triple = tuple(
            None if factor is None else float(factor)
            for factor in (scales_d, scales_h, scales_w)
        )
    else:
        scale_triple = None

    batch, channels, in_d, in_h, in_w = src.shape
    out_d, out_h, out_w, k_d, k_h, k_w = _compute_out_size_and_kernel_scales(
        in_d, in_h, in_w, output_size, scale_triple
    )
    result = torch.empty(
        (batch, channels, out_d, out_h, out_w),
        dtype=src.dtype,
        device=src.device,
        layout=src.layout,
    )
    return _launch_upsample_nearest3d(src, result, k_d, k_h, k_w)

214 

215 

def upsample_nearest3d_vec(input: torch.Tensor, output_size=None, scale_factors=None):
    """Nearest-neighbour 3D upsampling taking the scale factors as one sequence.

    Args:
        input: Tensor shaped [N, C, D, H, W].
        output_size: Optional (out_d, out_h, out_w). When given,
            ``scale_factors`` is not consulted.
        scale_factors: Optional sequence (scale_d, scale_h, scale_w), used
            only when ``output_size`` is ``None``.

    Returns:
        A new tensor of shape (N, C, OD, OH, OW) with the same dtype, device
        and layout as ``input``.

    Raises:
        ValueError: Propagated from the validation helpers for a non-5D
            input, a malformed ``output_size`` or malformed ``scale_factors``.
    """
    src = _ensure_5d_input(input)
    output_size = _normalize_output_size(output_size)

    # Scale factors only matter when no explicit output size was requested.
    scale_triple = (
        _normalize_scale_factors(scale_factors) if output_size is None else None
    )

    batch, channels, in_d, in_h, in_w = src.shape
    out_d, out_h, out_w, k_d, k_h, k_w = _compute_out_size_and_kernel_scales(
        in_d, in_h, in_w, output_size, scale_triple
    )
    result = torch.empty(
        (batch, channels, out_d, out_h, out_w),
        dtype=src.dtype,
        device=src.device,
        layout=src.layout,
    )
    return _launch_upsample_nearest3d(src, result, k_d, k_h, k_w)

230 

231 

def upsample_nearest3d_out(
    input: torch.Tensor,
    output_size=None,
    scales_d=None,
    scales_h=None,
    scales_w=None,
    out: torch.Tensor = None,
):
    """Nearest-neighbour 3D upsampling written into a caller-supplied tensor.

    Args:
        input: Tensor shaped [N, C, D, H, W].
        output_size: Optional (out_d, out_h, out_w). When given, the scale
            arguments are not consulted.
        scales_d, scales_h, scales_w: Per-axis scale factors, used only when
            ``output_size`` is ``None``.
        out: Destination tensor. It must be provided, share device and dtype
            with ``input``, and have shape (N, C, OD, OH, OW).

    Returns:
        ``out`` after the kernel launch.

    Raises:
        ValueError: If ``out`` is missing, has a different device/dtype, or
            has the wrong shape; also propagated from the validation helpers
            for the input, ``output_size`` and scale factors.
    """
    src = _ensure_5d_input(input)
    output_size = _normalize_output_size(output_size)

    if output_size is None:
        scale_triple = tuple(
            None if factor is None else float(factor)
            for factor in (scales_d, scales_h, scales_w)
        )
    else:
        scale_triple = None

    batch, channels, in_d, in_h, in_w = src.shape
    out_d, out_h, out_w, k_d, k_h, k_w = _compute_out_size_and_kernel_scales(
        in_d, in_h, in_w, output_size, scale_triple
    )

    # Destination checks run only after the size computation has succeeded.
    if out is None:
        raise ValueError("Argument 'out' must be provided for upsample_nearest3d_out")
    same_device = out.device == src.device
    same_dtype = out.dtype == src.dtype
    if not (same_device and same_dtype):
        raise ValueError(
            "Output tensor 'out' must have the same device and dtype as input"
        )
    expected_shape = (batch, channels, out_d, out_h, out_w)
    actual_shape = tuple(out.shape)
    if actual_shape != expected_shape:
        raise ValueError(
            f"Output tensor 'out' must have shape {expected_shape}, but got {tuple(out.shape)}"
        )

    _launch_upsample_nearest3d(src, out, k_d, k_h, k_w)
    return out