Coverage for src/flag_gems/ops/upsample_bicubic2d.py: 28%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import math 

3from typing import Sequence 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def cubic_weight(d, a: tl.constexpr):
    """Keys cubic-convolution weight for a sample at signed offset ``d``.

    Piecewise cubic parameterized by ``a`` (PyTorch/OpenCV use a = -0.75):
    one polynomial for |d| <= 1, another for 1 < |d| < 2, and 0 beyond.
    """
    dist = tl.abs(d)
    dist_sq = dist * dist
    dist_cu = dist_sq * dist
    # |d| <= 1 branch of the piecewise polynomial.
    near = (a + 2.0) * dist_cu - (a + 3.0) * dist_sq + 1.0
    # 1 < |d| < 2 branch.
    far = a * dist_cu - 5.0 * a * dist_sq + 8.0 * a * dist - 4.0 * a
    # Support is [0, 2): anything at distance >= 2 contributes nothing.
    outer = tl.where(dist < 2.0, far, 0.0)
    return tl.where(dist <= 1.0, near, outer)

20 

21 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_W": 128}, num_warps=4),
        triton.Config({"BLOCK_W": 256}, num_warps=4),
        triton.Config({"BLOCK_W": 256}, num_warps=8),
        triton.Config({"BLOCK_W": 512}, num_warps=8),
        triton.Config({"BLOCK_W": 1024}, num_warps=8),
    ],
    key=["W_out"],
)
@triton.jit
def _upsample_bicubic2d_row_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    strideN,
    strideC,
    strideH,
    strideW,
    out_strideN,
    out_strideC,
    out_strideH,
    out_strideW,
    scale_h,
    scale_w,
    align_corners: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Bicubic 2D upsampling, one program per (n, c, y_out, BLOCK_W-wide x tile).

    Each program computes up to BLOCK_W output pixels of a single output row
    by blending a 4x4 neighborhood of input pixels with separable cubic
    weights (cubic_weight with a = -0.75). Accumulation is in float32 and the
    result is cast back to the output dtype on store.
    """
    pid = tl.program_id(0)
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)

    # Decompose the flat program id into (n, c, y_out, w-tile).
    pid_w = pid % num_w_blocks
    row_id = pid // num_w_blocks

    y_out = row_id % H_out
    nc = row_id // H_out
    c = nc % C
    n = nc // C

    # Map the output row to a (fractional) input row coordinate.
    fy = y_out * 1.0
    if align_corners:
        in_y = fy * scale_h
    else:
        # Half-pixel-center convention when not aligning corners.
        in_y = (fy + 0.5) * scale_h - 0.5

    y0f = tl.floor(in_y)
    y0 = y0f.to(tl.int32)
    # Fractional distance of in_y above its floor, in [0, 1).
    ty = in_y - y0f

    # The four contributing input rows, clamped to [0, H_in - 1]
    # (replicate-border behavior at the edges).
    y_m1 = tl.maximum(0, tl.minimum(H_in - 1, y0 - 1))
    y_0 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 0))
    y_p1 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 1))
    y_p2 = tl.maximum(0, tl.minimum(H_in - 1, y0 + 2))

    # Cubic weights along y for the rows at offsets -1, 0, +1, +2.
    a = -0.75
    wy0 = cubic_weight(1.0 + ty, a)
    wy1 = cubic_weight(ty, a)
    wy2 = cubic_weight(1.0 - ty, a)
    wy3 = cubic_weight(2.0 - ty, a)

    # 64-bit offsets so large tensors don't overflow pointer arithmetic.
    n_64 = n.to(tl.int64)
    c_64 = c.to(tl.int64)
    base_ptr = in_ptr + n_64 * strideN + c_64 * strideC

    row_m1_ptr = base_ptr + y_m1.to(tl.int64) * strideH
    row_0_ptr = base_ptr + y_0.to(tl.int64) * strideH
    row_p1_ptr = base_ptr + y_p1.to(tl.int64) * strideH
    row_p2_ptr = base_ptr + y_p2.to(tl.int64) * strideH

    # Vector of output x coordinates handled by this program.
    x_out = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = x_out < W_out

    fx = x_out.to(tl.float32)
    if align_corners:
        in_x = fx * scale_w
    else:
        in_x = (fx + 0.5) * scale_w - 0.5

    x0f = tl.floor(in_x)
    x0 = x0f.to(tl.int32)
    tx = in_x - x0f

    # Four contributing input columns per lane, clamped into [0, W_in - 1].
    # This clamping also keeps the unmasked loads below in-bounds for lanes
    # with x_out >= W_out (their results are discarded by the masked store).
    x_m1 = tl.maximum(0, tl.minimum(W_in - 1, x0 - 1))
    x_0 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 0))
    x_p1 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 1))
    x_p2 = tl.maximum(0, tl.minimum(W_in - 1, x0 + 2))

    # Cubic weights along x for the columns at offsets -1, 0, +1, +2.
    wx0 = cubic_weight(1.0 + tx, a)
    wx1 = cubic_weight(tx, a)
    wx2 = cubic_weight(1.0 - tx, a)
    wx3 = cubic_weight(2.0 - tx, a)

    off_x_m1 = x_m1 * strideW
    off_x_0 = x_0 * strideW
    off_x_p1 = x_p1 * strideW
    off_x_p2 = x_p2 * strideW

    # Row y0 - 1: horizontal cubic blend, then scale by its vertical weight.
    v0 = tl.load(row_m1_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_m1_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_m1_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_m1_ptr + off_x_p2).to(tl.float32)
    acc = (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy0

    # Row y0.
    v0 = tl.load(row_0_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_0_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_0_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_0_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy1

    # Row y0 + 1.
    v0 = tl.load(row_p1_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_p1_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_p1_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_p1_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy2

    # Row y0 + 2.
    v0 = tl.load(row_p2_ptr + off_x_m1).to(tl.float32)
    v1 = tl.load(row_p2_ptr + off_x_0).to(tl.float32)
    v2 = tl.load(row_p2_ptr + off_x_p1).to(tl.float32)
    v3 = tl.load(row_p2_ptr + off_x_p2).to(tl.float32)
    acc += (v0 * wx0 + v1 * wx1 + v2 * wx2 + v3 * wx3) * wy3

    # Scatter the finished row tile; mask drops lanes past W_out.
    out_offset = (
        n_64 * out_strideN
        + c_64 * out_strideC
        + y_out.to(tl.int64) * out_strideH
        + x_out.to(tl.int64) * out_strideW
    )
    tl.store(out_ptr + out_offset, acc.to(out_ptr.dtype.element_ty), mask=mask)

155 

156 

def upsample_bicubic2d(
    input: torch.Tensor,
    output_size: Sequence[int] | None = None,
    align_corners: bool = False,
    scales_h: float | None = None,
    scales_w: float | None = None,
) -> torch.Tensor:
    """Bicubic 2D upsampling of a 4D NCHW CUDA tensor via a Triton kernel.

    Args:
        input: 4D tensor of shape (N, C, H, W); must reside on a CUDA device.
        output_size: Target (H_out, W_out). If omitted, both ``scales_h`` and
            ``scales_w`` must be given and the output size is
            ``max(floor(dim * scale), 1)`` per spatial dim.
        align_corners: If True, corner pixels of input and output are aligned
            and the coordinate map is ``y * (H_in - 1) / (H_out - 1)``;
            otherwise the half-pixel-center convention is used.
        scales_h: Optional height scale factor (used only to derive the
            output size when ``output_size`` is None).
        scales_w: Optional width scale factor (same role as ``scales_h``).

    Returns:
        New tensor of shape (N, C, H_out, W_out) with ``input``'s dtype.

    Raises:
        ValueError: On a non-4D input, a malformed ``output_size``, a
            non-positive output size, a non-CUDA input, or when neither
            ``output_size`` nor both scale factors are provided.
    """
    logger.debug("GEMS UPSAMPLE BICUBIC2D")

    if input.dim() != 4:
        raise ValueError("input must be a 4D tensor (N, C, H, W)")

    N, C, H_in, W_in = input.shape

    if output_size is not None:
        if len(output_size) != 2:
            raise ValueError(
                "output_size must be a sequence of two ints (H_out, W_out)"
            )
        H_out, W_out = int(output_size[0]), int(output_size[1])
    else:
        # BUGFIX: the original tested `scale_factors is None` on a tuple that
        # is never None, so calling with no output_size and missing scales
        # crashed with TypeError (float(None)) instead of raising ValueError.
        # The unreachable len(scale_factors) == 1 branch is dropped too.
        if scales_h is None or scales_w is None:
            raise ValueError("Either output_size or scale_factors must be provided")
        sh, sw = float(scales_h), float(scales_w)
        # Clamp to at least 1 so tiny scales never produce an empty output.
        H_out = max(int(math.floor(H_in * sh)), 1)
        W_out = max(int(math.floor(W_in * sw)), 1)

    if H_out <= 0 or W_out <= 0:
        raise ValueError("Output size must be positive")

    device = input.device
    if not input.is_cuda:
        raise ValueError("This Triton kernel requires CUDA tensors")

    # Coordinate-mapping scales passed to the kernel.
    if align_corners:
        # Degenerate 1-pixel outputs map everything to input coordinate 0.
        scale_h = 0.0 if H_out <= 1 else (H_in - 1.0) / (H_out - 1.0)
        scale_w = 0.0 if W_out <= 1 else (W_in - 1.0) / (W_out - 1.0)
    else:
        # NOTE(review): explicit scales_h/scales_w only derive the output
        # size above; the coordinate map always uses the in/out ratio.
        # PyTorch would use 1/scale here when a scale is supplied — confirm
        # whether exact parity with ATen is required.
        scale_h = float(H_in) / float(H_out)
        scale_w = float(W_in) / float(W_out)

    out = torch.empty((N, C, H_out, W_out), dtype=input.dtype, device=device)

    sN, sC, sH, sW = input.stride()
    oN, oC, oH, oW = out.stride()

    # One program per (n, c, output row, BLOCK_W-wide column tile).
    grid = lambda meta: (triton.cdiv(W_out, meta["BLOCK_W"]) * N * C * H_out,)

    _upsample_bicubic2d_row_kernel[grid](
        input,
        out,
        N,
        C,
        H_in,
        W_in,
        H_out,
        W_out,
        sN,
        sC,
        sH,
        sW,
        oN,
        oC,
        oH,
        oW,
        float(scale_h),
        float(scale_w),
        align_corners,
    )

    return out