Coverage for src/flag_gems/experimental_ops/_upsample_nearest_exact1d.py: 0%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

8@triton.jit 

9def _upsample_nearest_exact1d_kernel( 

10 in_ptr, 

11 out_ptr, 

12 N, 

13 C, 

14 IW, 

15 OW, 

16 sN_in, 

17 sC_in, 

18 sW_in, 

19 sN_out, 

20 sC_out, 

21 sW_out, 

22 use_scales: tl.constexpr, 

23 scale_w, 

24 BLOCK_W: tl.constexpr, 

25): 

26 pid_w = tl.program_id(0) 

27 pid_nc = tl.program_id(1) 

28 

29 offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W) 

30 mask = offs_w < OW 

31 

32 # Compute (n, c) from flattened plane index 

33 nc = pid_nc 

34 n = nc // C 

35 c = nc - n * C 

36 

37 base_in = n * sN_in + c * sC_in 

38 base_out = n * sN_out + c * sC_out 

39 

40 # Compute source indices iw for each output index ow 

41 iw = tl.zeros([BLOCK_W], dtype=tl.int32) 

42 if use_scales: 

43 ow_f = offs_w.to(tl.float32) 

44 iw_f = tl.floor(ow_f / scale_w) 

45 iw = iw_f.to(tl.int32) 

46 else: 

47 iw = (offs_w * IW) // OW 

48 iw = tl.minimum(iw, IW - 1) 

49 

50 in_ptrs = in_ptr + base_in + iw * sW_in 

51 x = tl.load(in_ptrs, mask=mask) 

52 

53 out_ptrs = out_ptr + base_out + offs_w * sW_out 

54 tl.store(out_ptrs, x, mask=mask) 

55 

56 

57def _parse_size_1d(val): 

58 if val is None: 

59 return None 

60 if isinstance(val, torch.Size): 

61 return int(val[-1]) if len(val) > 0 else None 

62 if isinstance(val, (list, tuple)): 

63 if len(val) == 0: 

64 return None 

65 return int(val[-1]) 

66 return int(val) 

67 

68 

69def _parse_scale_1d(val): 

70 if val is None: 

71 return None 

72 if isinstance(val, (list, tuple)): 

73 if len(val) == 0: 

74 return None 

75 return float(val[-1]) 

76 return float(val) 

77 

78 

79def _compute_out_w(iw, output_size, scale): 

80 if output_size is not None: 

81 return int(output_size) 

82 if scale is None: 

83 raise ValueError( 

84 "Either output_size or scale must be provided for _upsample_nearest_exact1d." 

85 ) 

86 # Follow common convention: OW = floor(IW * scale) 

87 return int(math.floor(iw * scale)) 

88 

89 

90def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None): 

91 if input.ndim != 3: 

92 raise ValueError( 

93 f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}" 

94 ) 

95 if not input.is_cuda or not out.is_cuda: 

96 # Fallback to the native operator on CPU or non-CUDA devices 

97 return torch.ops.aten._upsample_nearest_exact1d( 

98 input, [out.shape[-1]], [scale] if scale is not None else None 

99 ) 

100 

101 N, C, IW = input.shape 

102 OW = out.shape[-1] 

103 

104 sN_in, sC_in, sW_in = input.stride() 

105 sN_out, sC_out, sW_out = out.stride() 

106 

107 BLOCK_W = 256 

108 grid = (triton.cdiv(OW, BLOCK_W), N * C) 

109 

110 use_scales = scale is not None and output_size is None 

111 scale_w = float(scale) if use_scales else 1.0 

112 

113 _upsample_nearest_exact1d_kernel[grid]( 

114 input, 

115 out, 

116 N, 

117 C, 

118 IW, 

119 OW, 

120 sN_in, 

121 sC_in, 

122 sW_in, 

123 sN_out, 

124 sC_out, 

125 sW_out, 

126 use_scales=use_scales, 

127 scale_w=scale_w, 

128 BLOCK_W=BLOCK_W, 

129 ) 

130 return out 

131 

132 

133def _extract_io_and_params(args, kwargs, expect_out=False): 

134 # Extract input tensor 

135 in_t = kwargs.get("input", None) 

136 if in_t is None: 

137 in_t = kwargs.get("self", None) 

138 if in_t is None and len(args) > 0 and isinstance(args[0], torch.Tensor): 

139 in_t = args[0] 

140 args = args[1:] 

141 if in_t is None or not isinstance(in_t, torch.Tensor): 

142 raise ValueError("Input tensor not found for _upsample_nearest_exact1d.") 

143 

144 # Extract output_size / scales from kwargs or remaining args 

145 output_size = kwargs.get( 

146 "output_size", kwargs.get("size", kwargs.get("output_size_list", None)) 

147 ) 

148 scales = kwargs.get( 

149 "scale_factor", 

150 kwargs.get("scales", kwargs.get("scale_factors", kwargs.get("scale", None))), 

151 ) 

152 

153 # If positional arguments contain size and/or scales 

154 # Try to interpret next positional as output_size if present and not a tensor 

155 pos = 0 

156 if ( 

157 output_size is None 

158 and pos < len(args) 

159 and not isinstance(args[pos], torch.Tensor) 

160 ): 

161 output_size = args[pos] 

162 pos += 1 

163 if scales is None and pos < len(args) and not isinstance(args[pos], torch.Tensor): 

164 scales = args[pos] 

165 pos += 1 

166 

167 out_t = None 

168 if expect_out: 

169 out_t = kwargs.get("out", None) 

170 if out_t is None: 

171 # find last tensor among remaining args as out 

172 for a in reversed(args): 

173 if isinstance(a, torch.Tensor): 

174 out_t = a 

175 break 

176 if out_t is None: 

177 raise ValueError( 

178 "Output tensor 'out' not found for _upsample_nearest_exact1d_out." 

179 ) 

180 

181 # Normalize single-dim size and scale 

182 out_w = _parse_size_1d(output_size) 

183 scale_w = _parse_scale_1d(scales) 

184 

185 return in_t, out_t, out_w, scale_w 

186 

187 

188def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None): 

189 N, C, IW = in_t.shape 

190 OW = _compute_out_w(IW, out_w, scale_w) 

191 if OW < 0: 

192 raise ValueError("Output width must be non-negative.") 

193 if dtype is None: 

194 dtype = in_t.dtype 

195 if device is None: 

196 device = in_t.device 

197 return torch.empty((N, C, OW), dtype=dtype, device=device) 

198 

199 

200def _upsample_nearest_exact1d(*args, **kwargs): 

201 in_t, _, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=False) 

202 out_t = _prepare_out_tensor(in_t, out_w, scale_w) 

203 if out_t.numel() == 0: 

204 return out_t 

205 return _launch_upsample_nearest_exact1d_kernel( 

206 in_t, out_t, output_size=out_w, scale=scale_w 

207 ) 

208 

209 

210def _upsample_nearest_exact1d_out(*args, **kwargs): 

211 in_t, out_t, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=True) 

212 if out_t.ndim != 3: 

213 raise ValueError( 

214 f"Out tensor must be 3D (N, C, W); got shape {tuple(out_t.shape)}" 

215 ) 

216 # Validate that out_t has the correct computed width if parameters are provided 

217 expected_w = _compute_out_w(in_t.shape[-1], out_w, scale_w) 

218 if out_t.shape[-1] != expected_w: 

219 raise ValueError( 

220 f"Provided out tensor has width {out_t.shape[-1]} but expected {expected_w}." 

221 ) 

222 if out_t.numel() == 0: 

223 return out_t 

224 return _launch_upsample_nearest_exact1d_kernel( 

225 in_t, out_t, output_size=out_w, scale=scale_w 

226 ) 

227 

228 

229def _upsample_nearest_exact1d_vec(*args, **kwargs): 

230 # Treat vec the same as base variant, allowing list-like output_size/scales 

231 in_t, _, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=False) 

232 out_t = _prepare_out_tensor(in_t, out_w, scale_w) 

233 if out_t.numel() == 0: 

234 return out_t 

235 return _launch_upsample_nearest_exact1d_kernel( 

236 in_t, out_t, output_size=out_w, scale=scale_w 

237 )