Coverage for src/flag_gems/ops/upsample_bicubic2d_aa.py: 10%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Rebind ``device`` from the imported runtime device object to its ``name``
# attribute; the rest of this module compares it against ``tensor.device.type``.
device = device.name

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

15 

16 

@triton.jit
def _bicubic_aa_filter(dist):
    """Evaluate the piecewise cubic filter (a = -0.5) at tap distance ``dist``.

    Callers pass an absolute distance; the result is 0 for ``dist >= 2``.
    """
    a = -0.5
    near = ((a + 2) * dist - (a + 3)) * dist * dist + 1
    far = (((dist - 5) * dist + 8) * dist - 4) * a
    return tl.where(dist < 1.0, near, tl.where(dist < 2.0, far, 0))


@triton.autotune(
    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
    key=["N", "C", "OH", "OW"],
)
@triton.jit
def upsample_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    """Bicubic resampling with a fixed filter support of 2.0 per axis.

    Each program handles a BLOCK_Y x BLOCK_X tile of output pixels and loops
    over every (n, c) plane.  Per axis, at most five consecutive input taps
    are weighted, the weights are normalised by their sum, and the 5x5
    neighbourhood is reduced first along x and then along y.

    Both tensors are addressed with dense offsets of the form
    ``((n * C + c) * H + h) * W + w``.
    """
    tile_x = tle.program_id(axis=0)
    tile_y = tle.program_id(axis=1)
    # Lanes past the edge of the output wrap around (modulo) rather than
    # being masked off, so every lane maps to a valid output pixel.
    ow = (tile_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (tile_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    support_w = 2.0
    support_h = 2.0
    invscale_w = 1.0
    invscale_h = 1.0

    # Tap window per output coordinate: first input index and tap count.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    first_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    first_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    count_w = (tl.minimum(center_w + support_w + 0.5, IW) - first_w).to(tl.int32)
    count_h = (tl.minimum(center_h + support_h + 0.5, IH) - first_h).to(tl.int32)
    shift_w = first_w - center_w
    shift_h = first_h - center_h

    # Vertical weights; taps at or beyond the window size contribute 0.
    dist_y0 = tl.abs((0 + shift_h + 0.5) * invscale_h)
    dist_y1 = tl.abs((1 + shift_h + 0.5) * invscale_h)
    dist_y2 = tl.abs((2 + shift_h + 0.5) * invscale_h)
    dist_y3 = tl.abs((3 + shift_h + 0.5) * invscale_h)
    dist_y4 = tl.abs((4 + shift_h + 0.5) * invscale_h)
    weight_y0 = tl.where(0 < count_h, _bicubic_aa_filter(dist_y0), 0)
    weight_y1 = tl.where(1 < count_h, _bicubic_aa_filter(dist_y1), 0)
    weight_y2 = tl.where(2 < count_h, _bicubic_aa_filter(dist_y2), 0)
    weight_y3 = tl.where(3 < count_h, _bicubic_aa_filter(dist_y3), 0)
    weight_y4 = tl.where(4 < count_h, _bicubic_aa_filter(dist_y4), 0)
    sum_y = weight_y0 + weight_y1 + weight_y2 + weight_y3 + weight_y4
    # A zero sum is replaced by 1 so the normalisation never divides by 0.
    sum_y = tl.where(sum_y != 0, sum_y, 1)
    weight_y0 /= sum_y
    weight_y1 /= sum_y
    weight_y2 /= sum_y
    weight_y3 /= sum_y
    weight_y4 /= sum_y

    # Horizontal weights, built the same way.
    dist_x0 = tl.abs((0 + shift_w + 0.5) * invscale_w)
    dist_x1 = tl.abs((1 + shift_w + 0.5) * invscale_w)
    dist_x2 = tl.abs((2 + shift_w + 0.5) * invscale_w)
    dist_x3 = tl.abs((3 + shift_w + 0.5) * invscale_w)
    dist_x4 = tl.abs((4 + shift_w + 0.5) * invscale_w)
    weight_x0 = tl.where(0 < count_w, _bicubic_aa_filter(dist_x0), 0)
    weight_x1 = tl.where(1 < count_w, _bicubic_aa_filter(dist_x1), 0)
    weight_x2 = tl.where(2 < count_w, _bicubic_aa_filter(dist_x2), 0)
    weight_x3 = tl.where(3 < count_w, _bicubic_aa_filter(dist_x3), 0)
    weight_x4 = tl.where(4 < count_w, _bicubic_aa_filter(dist_x4), 0)
    sum_x = weight_x0 + weight_x1 + weight_x2 + weight_x3 + weight_x4
    sum_x = tl.where(sum_x != 0, sum_x, 1)
    weight_x0 /= sum_x
    weight_x1 /= sum_x
    weight_x2 /= sum_x
    weight_x3 /= sum_x
    weight_x4 /= sum_x

    # Bounds masks so no tap is loaded from outside the input plane.
    mask_y0 = first_h[:, None] + 0 < IH
    mask_y1 = first_h[:, None] + 1 < IH
    mask_y2 = first_h[:, None] + 2 < IH
    mask_y3 = first_h[:, None] + 3 < IH
    mask_y4 = first_h[:, None] + 4 < IH
    mask_x0 = first_w[None, :] + 0 < IW
    mask_x1 = first_w[None, :] + 1 < IW
    mask_x2 = first_w[None, :] + 2 < IW
    mask_x3 = first_w[None, :] + 3 < IW
    mask_x4 = first_w[None, :] + 4 < IW

    # Broadcast views of the weights, shared by every (n, c) plane.
    wx0 = weight_x0[None, :]
    wx1 = weight_x1[None, :]
    wx2 = weight_x2[None, :]
    wx3 = weight_x3[None, :]
    wx4 = weight_x4[None, :]
    wy0 = weight_y0[:, None]
    wy1 = weight_y1[:, None]
    wy2 = weight_y2[:, None]
    wy3 = weight_y3[:, None]
    wy4 = weight_y4[:, None]

    for n in range(0, N, 1):
        for c in range(0, C, 1):
            offset_base = ((n * C + c) * IH + first_h[:, None]) * IW + first_w[None, :]
            row0 = offset_base + 0 * IW
            row1 = offset_base + 1 * IW
            row2 = offset_base + 2 * IW
            row3 = offset_base + 3 * IW
            row4 = offset_base + 4 * IW

            data00 = tl.load(ptr_i + (row0 + 0), mask=(mask_y0 & mask_x0), other=0)
            data01 = tl.load(ptr_i + (row0 + 1), mask=(mask_y0 & mask_x1), other=0)
            data02 = tl.load(ptr_i + (row0 + 2), mask=(mask_y0 & mask_x2), other=0)
            data03 = tl.load(ptr_i + (row0 + 3), mask=(mask_y0 & mask_x3), other=0)
            data04 = tl.load(ptr_i + (row0 + 4), mask=(mask_y0 & mask_x4), other=0)

            data10 = tl.load(ptr_i + (row1 + 0), mask=(mask_y1 & mask_x0), other=0)
            data11 = tl.load(ptr_i + (row1 + 1), mask=(mask_y1 & mask_x1), other=0)
            data12 = tl.load(ptr_i + (row1 + 2), mask=(mask_y1 & mask_x2), other=0)
            data13 = tl.load(ptr_i + (row1 + 3), mask=(mask_y1 & mask_x3), other=0)
            data14 = tl.load(ptr_i + (row1 + 4), mask=(mask_y1 & mask_x4), other=0)

            data20 = tl.load(ptr_i + (row2 + 0), mask=(mask_y2 & mask_x0), other=0)
            data21 = tl.load(ptr_i + (row2 + 1), mask=(mask_y2 & mask_x1), other=0)
            data22 = tl.load(ptr_i + (row2 + 2), mask=(mask_y2 & mask_x2), other=0)
            data23 = tl.load(ptr_i + (row2 + 3), mask=(mask_y2 & mask_x3), other=0)
            data24 = tl.load(ptr_i + (row2 + 4), mask=(mask_y2 & mask_x4), other=0)

            data30 = tl.load(ptr_i + (row3 + 0), mask=(mask_y3 & mask_x0), other=0)
            data31 = tl.load(ptr_i + (row3 + 1), mask=(mask_y3 & mask_x1), other=0)
            data32 = tl.load(ptr_i + (row3 + 2), mask=(mask_y3 & mask_x2), other=0)
            data33 = tl.load(ptr_i + (row3 + 3), mask=(mask_y3 & mask_x3), other=0)
            data34 = tl.load(ptr_i + (row3 + 4), mask=(mask_y3 & mask_x4), other=0)

            data40 = tl.load(ptr_i + (row4 + 0), mask=(mask_y4 & mask_x0), other=0)
            data41 = tl.load(ptr_i + (row4 + 1), mask=(mask_y4 & mask_x1), other=0)
            data42 = tl.load(ptr_i + (row4 + 2), mask=(mask_y4 & mask_x2), other=0)
            data43 = tl.load(ptr_i + (row4 + 3), mask=(mask_y4 & mask_x3), other=0)
            data44 = tl.load(ptr_i + (row4 + 4), mask=(mask_y4 & mask_x4), other=0)

            # Reduce each tap row along x ...
            data0 = data00 * wx0 + data01 * wx1 + data02 * wx2 + data03 * wx3 + data04 * wx4
            data1 = data10 * wx0 + data11 * wx1 + data12 * wx2 + data13 * wx3 + data14 * wx4
            data2 = data20 * wx0 + data21 * wx1 + data22 * wx2 + data23 * wx3 + data24 * wx4
            data3 = data30 * wx0 + data31 * wx1 + data32 * wx2 + data33 * wx3 + data34 * wx4
            data4 = data40 * wx0 + data41 * wx1 + data42 * wx2 + data43 * wx3 + data44 * wx4
            # ... then combine the five rows along y.
            result = data0 * wy0 + data1 * wy1 + data2 * wy2 + data3 * wy3 + data4 * wy4

            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)

368 

369 

370# upsample and downsample 

@triton.autotune(
    configs=runtime.get_tuned_config("upsample_bicubic2d_aa"),
    key=["N", "C", "OH", "OW"],
)
@triton.jit
def general_interpolate_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    """Bicubic resampling whose filter support scales with the sampling step.

    Along an axis whose reciprocal scale is >= 1.0 the support becomes
    ``2 * reciprocal_scale`` and tap distances are multiplied by
    ``1 / reciprocal_scale``; otherwise the support is 2.0 and distances are
    used unscaled.  The number of taps is derived at run time, so the taps are
    visited with loops rather than unrolled.

    Each program handles a BLOCK_Y x BLOCK_X tile of output pixels and loops
    over every (n, c) plane.  Both tensors are addressed with dense offsets of
    the form ``((n * C + c) * H + h) * W + w``.
    """
    tile_x = tle.program_id(axis=0)
    tile_y = tle.program_id(axis=1)
    # Lanes past the edge of the output wrap around (modulo) rather than
    # being masked off.
    ow = (tile_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (tile_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    support_w = 2 * reciprocal_scale_w if (reciprocal_scale_w >= 1.0) else 2.0
    support_h = 2 * reciprocal_scale_h if (reciprocal_scale_h >= 1.0) else 2.0

    # Upper bound on the number of taps visited per axis.
    taps_w = (support_w + 0.5).to(tl.int32) * 2 + 1
    taps_h = (support_h + 0.5).to(tl.int32) * 2 + 1

    # Tap window per output coordinate: first input index and tap count.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    first_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    first_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    count_w = (tl.minimum(center_w + support_w + 0.5, IW) - first_w).to(tl.int32)
    count_h = (tl.minimum(center_h + support_h + 0.5, IH) - first_h).to(tl.int32)

    invscale_w = 1.0 / reciprocal_scale_w if (reciprocal_scale_w >= 1.0) else 1.0
    invscale_h = 1.0 / reciprocal_scale_h if (reciprocal_scale_h >= 1.0) else 1.0
    shift_w = first_w - center_w
    shift_h = first_h - center_h

    a = -0.5
    for n in range(0, N, 1):
        for c in range(0, C, 1):
            offset_base = ((n * C + c) * IH + first_h[:, None]) * IW + first_w[None, :]
            sum_y = tl.zeros((BLOCK_Y,), dtype=tl.float32)
            result = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
            for y in range(0, taps_h, 1):
                dist_y = tl.abs((y + shift_h + 0.5) * invscale_h)
                near_y = ((a + 2) * dist_y - (a + 3)) * dist_y * dist_y + 1
                far_y = (((dist_y - 5) * dist_y + 8) * dist_y - 4) * a
                cubic_y = tl.where(dist_y < 1.0, near_y, tl.where(dist_y < 2.0, far_y, 0))
                # Taps at or beyond the window size contribute 0.
                weight_y = tl.where(y < count_h, cubic_y, 0)
                sum_y += weight_y

                sum_x = tl.zeros((BLOCK_X,), dtype=tl.float32)
                row_acc = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
                for x in range(0, taps_w, 1):
                    dist_x = tl.abs((x + shift_w + 0.5) * invscale_w)
                    near_x = ((a + 2) * dist_x - (a + 3)) * dist_x * dist_x + 1
                    far_x = (((dist_x - 5) * dist_x + 8) * dist_x - 4) * a
                    cubic_x = tl.where(
                        dist_x < 1.0, near_x, tl.where(dist_x < 2.0, far_x, 0)
                    )
                    weight_x = tl.where(x < count_w, cubic_x, 0)
                    sum_x += weight_x
                    in_bounds = (first_h[:, None] + y < IH) & (first_w[None, :] + x < IW)
                    data = tl.load(
                        ptr_i + (offset_base + y * IW + x),
                        mask=in_bounds,
                        other=0,
                    )
                    row_acc += data * weight_x[None, :]
                # A zero sum is replaced by 1 so the normalisation never
                # divides by 0.
                sum_x = tl.where(sum_x != 0, sum_x, 1)
                result += row_acc / sum_x[None, :] * weight_y[:, None]
            sum_y = tl.where(sum_y != 0, sum_y, 1)
            result /= sum_y[:, None]
            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)

463 

464 

def bicubic_reciprocal_scale(
    src_size: int,
    dst_size: int,
    align_corners: bool,
    scale: Optional[float],
) -> float:
    """Return the input-pixels-per-output-pixel step along one axis.

    Args:
        src_size: Extent of the input along the axis.
        dst_size: Extent of the output along the axis.
        align_corners: When true the step is ``(src_size - 1) / (dst_size - 1)``,
            or ``0.0`` if ``dst_size`` is not greater than 1; ``scale`` is
            ignored in this mode.
        scale: Optional user-supplied scale factor.  When it is not ``None``
            and positive, the step is ``1.0 / scale``; otherwise the step is
            ``src_size / dst_size``.

    Returns:
        The step as a Python float in every branch.
    """
    if align_corners:
        if dst_size > 1:
            return (src_size - 1) / (dst_size - 1)
        # Return a float (not the int 0) so callers always receive the same
        # type regardless of which branch was taken.
        return 0.0
    if scale is not None and scale > 0:
        return 1.0 / scale
    return src_size / dst_size

476 

477 

478# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L12547 

def _upsample_bicubic2d_aa(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    align_corners: bool = False,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Resize a 4-D tensor to ``output_size`` with the bicubic Triton kernels.

    Args:
        input: 4-D tensor, unpacked as ``(N, C, IH, IW)``; must live on the
            device this module was configured for.
        output_size: Target ``(OH, OW)``.
        align_corners: Forwarded to ``bicubic_reciprocal_scale`` when deriving
            the per-axis sampling step.
        scales_h: Optional scale factor for the height axis.
        scales_w: Optional scale factor for the width axis.

    Returns:
        A new tensor of shape ``(N, C, OH, OW)`` with the dtype and device of
        ``input``.

    Raises:
        AssertionError: If the device type, ``input.ndim`` or
            ``len(output_size)`` does not match the expectations above.
    """
    logger.debug("GEMS UPSAMPLE BICUBIC2D AA")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    N, C, IH, IW = input.shape

    # The kernels index the input with dense ``((n * C + c) * IH + h) * IW + w``
    # offsets, so a strided view (e.g. transposed or channels_last) has to be
    # packed first.  This is a no-op for tensors that are already contiguous.
    input = input.contiguous()

    reciprocal_scale_h = bicubic_reciprocal_scale(IH, OH, align_corners, scales_h)
    reciprocal_scale_w = bicubic_reciprocal_scale(IW, OW, align_corners, scales_w)

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    grid = lambda META: (
        triton.cdiv(OW, META["BLOCK_X"]),
        triton.cdiv(OH, META["BLOCK_Y"]),
    )
    # Any axis with a step >= 1.0 needs the kernel with the scaled filter
    # support; otherwise the fixed five-tap kernel is used.
    kernel = (
        general_interpolate_bicubic2d_aa_kernel
        if (reciprocal_scale_w >= 1.0) or (reciprocal_scale_h >= 1.0)
        else upsample_bicubic2d_aa_kernel
    )
    with torch_device_fn.device(input.device):
        kernel[grid](
            output,
            input,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
        )
    return output