Coverage for src/flag_gems/ops/_upsample_nearest_exact1d.py: 53%

133 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def _upsample_nearest_exact1d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    IW,
    OW,
    sN_in,
    sC_in,
    sW_in,
    sN_out,
    sC_out,
    sW_out,
    use_scales: tl.constexpr,
    scale_w,
    BLOCK_W: tl.constexpr,
):
    """Nearest-exact 1D upsampling: one program per BLOCK_W output columns
    of one (n, c) plane.

    Grid: axis 0 tiles the output width, axis 1 enumerates N*C planes.

    FIX: the previous code used the legacy "nearest" index convention
    iw = floor(ow * IW / OW) (resp. floor(ow / scale)).  The *exact*
    variant adds the half-pixel offset: iw = floor((ow + 0.5) * IW / OW)
    (resp. floor((ow + 0.5) / scale)), matching aten's
    _upsample_nearest_exact1d, so the CUDA path now agrees with the CPU
    fallback.
    """
    pid_w = tl.program_id(0)
    pid_nc = tl.program_id(1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < OW

    # Recover (n, c) from the flattened plane index.
    n = pid_nc // C
    c = pid_nc - n * C

    base_in = n * sN_in + c * sC_in
    base_out = n * sN_out + c * sC_out

    # Source index iw for each output index ow (nearest-exact convention).
    if use_scales:
        # floor((ow + 0.5) / scale_w) computed in float32.
        center = offs_w.to(tl.float32) + 0.5
        iw = tl.floor(center / scale_w).to(tl.int32)
    else:
        # Exact integer form of floor((ow + 0.5) * IW / OW):
        # floor(((2*ow + 1) * IW) / (2*OW)) — avoids float rounding entirely.
        iw = ((2 * offs_w + 1) * IW) // (2 * OW)
    # Clamp to the valid range: float rounding can overshoot the top end,
    # and a pathological scale could go negative.
    iw = tl.maximum(iw, 0)
    iw = tl.minimum(iw, IW - 1)

    in_ptrs = in_ptr + base_in + iw * sW_in
    x = tl.load(in_ptrs, mask=mask)

    out_ptrs = out_ptr + base_out + offs_w * sW_out
    tl.store(out_ptrs, x, mask=mask)

59 

60 

61def _parse_size_1d(val): 

62 if val is None: 

63 return None 

64 if isinstance(val, torch.Size): 

65 return int(val[-1]) if len(val) > 0 else None 

66 if isinstance(val, (list, tuple)): 

67 if len(val) == 0: 

68 return None 

69 return int(val[-1]) 

70 return int(val) 

71 

72 

73def _parse_scale_1d(val): 

74 if val is None: 

75 return None 

76 if isinstance(val, (list, tuple)): 

77 if len(val) == 0: 

78 return None 

79 return float(val[-1]) 

80 return float(val) 

81 

82 

83def _compute_out_w(iw, output_size, scale): 

84 if output_size is not None: 

85 return int(output_size) 

86 if scale is None: 

87 raise ValueError( 

88 "Either output_size or scale must be provided for _upsample_nearest_exact1d." 

89 ) 

90 # Follow common convention: OW = floor(IW * scale) 

91 return int(math.floor(iw * scale)) 

92 

93 

94def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None): 

95 if input.ndim != 3: 

96 raise ValueError( 

97 f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}" 

98 ) 

99 if not input.is_cuda or not out.is_cuda: 

100 # Fallback to the native operator on CPU or non-CUDA devices 

101 return torch.ops.aten._upsample_nearest_exact1d( 

102 input, [out.shape[-1]], [scale] if scale is not None else None 

103 ) 

104 

105 N, C, IW = input.shape 

106 OW = out.shape[-1] 

107 

108 sN_in, sC_in, sW_in = input.stride() 

109 sN_out, sC_out, sW_out = out.stride() 

110 

111 BLOCK_W = 256 

112 grid = (triton.cdiv(OW, BLOCK_W), N * C) 

113 

114 use_scales = scale is not None and output_size is None 

115 scale_w = float(scale) if use_scales else 1.0 

116 

117 _upsample_nearest_exact1d_kernel[grid]( 

118 input, 

119 out, 

120 N, 

121 C, 

122 IW, 

123 OW, 

124 sN_in, 

125 sC_in, 

126 sW_in, 

127 sN_out, 

128 sC_out, 

129 sW_out, 

130 use_scales=use_scales, 

131 scale_w=scale_w, 

132 BLOCK_W=BLOCK_W, 

133 ) 

134 return out 

135 

136 

def _extract_io_and_params(args, kwargs, expect_out=False):
    """Parse a flexible ``*args, **kwargs`` call into (input, out, width, scale).

    Accepts the many calling conventions this op is invoked with (aten-style
    positional args, torch keyword spellings such as ``size``/``scale_factor``,
    and an optional ``out=``) and returns a 4-tuple
    ``(in_t, out_t, out_w, scale_w)`` where ``out_t`` is None unless
    ``expect_out`` is True, ``out_w`` is an int width or None, and ``scale_w``
    is a float or None.

    Raises ValueError when the input tensor (or, with ``expect_out=True``,
    the output tensor) cannot be located.

    NOTE(review): the positional scan below is order-dependent — it assumes
    the first non-tensor positional after the input is the size and the next
    one the scale, matching the aten overload ordering.
    """
    # Extract input tensor: keyword "input" or "self" first, else the first
    # positional tensor (which is then stripped from `args` so the later
    # positional scan only sees the remaining arguments).
    in_t = kwargs.get("input", None)
    if in_t is None:
        in_t = kwargs.get("self", None)
    if in_t is None and len(args) > 0 and isinstance(args[0], torch.Tensor):
        in_t = args[0]
        args = args[1:]
    if in_t is None or not isinstance(in_t, torch.Tensor):
        raise ValueError("Input tensor not found for _upsample_nearest_exact1d.")

    # Extract output_size / scales from kwargs or remaining args.
    # Several keyword spellings are accepted; the first match wins.
    output_size = kwargs.get(
        "output_size", kwargs.get("size", kwargs.get("output_size_list", None))
    )
    scales = kwargs.get(
        "scale_factor",
        kwargs.get("scales", kwargs.get("scale_factors", kwargs.get("scale", None))),
    )

    # If positional arguments contain size and/or scales:
    # try to interpret the next positional as output_size if present and not
    # a tensor (tensors are skipped — they may be the out= argument).
    pos = 0
    if (
        output_size is None
        and pos < len(args)
        and not isinstance(args[pos], torch.Tensor)
    ):
        output_size = args[pos]
        pos += 1
    if scales is None and pos < len(args) and not isinstance(args[pos], torch.Tensor):
        scales = args[pos]
        pos += 1

    out_t = None
    if expect_out:
        out_t = kwargs.get("out", None)
        if out_t is None:
            # Find the last tensor among remaining args as out (the out=
            # argument trails the overload's positional parameters).
            for a in reversed(args):
                if isinstance(a, torch.Tensor):
                    out_t = a
                    break
        if out_t is None:
            raise ValueError(
                "Output tensor 'out' not found for _upsample_nearest_exact1d_out."
            )

    # Normalize single-dim size and scale (list-like values reduce to their
    # trailing element; see _parse_size_1d / _parse_scale_1d).
    out_w = _parse_size_1d(output_size)
    scale_w = _parse_scale_1d(scales)

    return in_t, out_t, out_w, scale_w

190 

191 

def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None):
    """Allocate an uninitialized (N, C, OW) output tensor for `in_t`.

    OW comes from `out_w` / `scale_w` via _compute_out_w; dtype and device
    default to those of `in_t` unless explicitly overridden.
    """
    n, c, iw = in_t.shape
    ow = _compute_out_w(iw, out_w, scale_w)
    if ow < 0:
        raise ValueError("Output width must be non-negative.")
    return torch.empty(
        (n, c, ow),
        dtype=in_t.dtype if dtype is None else dtype,
        device=in_t.device if device is None else device,
    )

202 

203 

def _upsample_nearest_exact1d(*args, **kwargs):
    """Functional variant: allocate the output tensor, fill it, return it."""
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D")
    src, _, width, scale = _extract_io_and_params(args, kwargs, expect_out=False)
    dst = _prepare_out_tensor(src, width, scale)
    # Nothing to compute for an empty output; skip the kernel launch.
    if dst.numel():
        dst = _launch_upsample_nearest_exact1d_kernel(
            src, dst, output_size=width, scale=scale
        )
    return dst

213 

214 

def _upsample_nearest_exact1d_out(*args, **kwargs):
    """out= variant: validate the provided output tensor, then fill it in place."""
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_OUT")
    src, dst, width, scale = _extract_io_and_params(args, kwargs, expect_out=True)
    if dst.ndim != 3:
        raise ValueError(
            f"Out tensor must be 3D (N, C, W); got shape {tuple(dst.shape)}"
        )
    # The caller-supplied width must match what the parameters imply.
    expected = _compute_out_w(src.shape[-1], width, scale)
    actual = dst.shape[-1]
    if actual != expected:
        raise ValueError(
            f"Provided out tensor has width {actual} but expected {expected}."
        )
    # Nothing to compute for an empty output; skip the kernel launch.
    if dst.numel() == 0:
        return dst
    return _launch_upsample_nearest_exact1d_kernel(
        src, dst, output_size=width, scale=scale
    )

233 

234 

def _upsample_nearest_exact1d_vec(*args, **kwargs):
    """Vec-overload entry point; same semantics as the base variant
    (list-valued output_size / scale_factors are normalized upstream)."""
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_VEC")
    src, _, width, scale = _extract_io_and_params(args, kwargs, expect_out=False)
    result = _prepare_out_tensor(src, width, scale)
    if result.numel() == 0:
        # Empty output: nothing to launch.
        return result
    return _launch_upsample_nearest_exact1d_kernel(
        src, result, output_size=width, scale=scale
    )