Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_bicubic2d_aa.py: 0%

200 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger, created as a child of the shared "flag_gems" logger; any
# leading dots are stripped from the module name before it is used as the
# child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebind `device` from the imported runtime object to its name string. The
# device-type assertion in _upsample_bicubic2d_aa compares against this string.
device = device.name

14 

15 

def configs():
    """Return the candidate ``triton.Config`` objects for this operator.

    Every combination of BLOCK_X in (512, 256, 128, 64), BLOCK_Y in (2, 1)
    and num_warps in (4, 8) is produced, with BLOCK_X varying slowest and
    num_warps varying fastest.
    """
    candidates = []
    for block_x in (512, 256, 128, 64):
        for block_y in (2, 1):
            for num_warps in (4, 8):
                meta = {
                    "BLOCK_X": block_x,
                    "BLOCK_Y": block_y,
                }
                candidates.append(triton.Config(meta, num_warps=num_warps))
    return candidates

30 

31 

def heur_m_block_size(args):
    """Heuristic for BLOCK_X: the output width split 12 ways, rounded up.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["OW"]`` is read. The result is the next power of two of
    ceil(OW / 12).
    """
    # NOTE(review): the original comment labels 12 as "cluster_num" --
    # presumably the number of hardware clusters; confirm for the target.
    columns_per_cluster = triton.cdiv(args["OW"], 12)
    return triton.next_power_of_2(columns_per_cluster)

34 

35 

def heur_n_block_size(args):
    """Heuristic for BLOCK_Y: always one output row per program.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    it is accepted for interface compatibility and not inspected.
    """
    # The statements that used to follow this return (an OH-dependent block
    # size capped at 8192) were unreachable and have been removed.
    return 1

41 

42 

43# @triton.autotune( 

44# configs=runtime.get_tuned_config("upsample_bicubic2d_aa"), 

45# key=["N", "C", "OH", "OW"], 

46# ) 

# Fixed-footprint bicubic kernel: every output pixel is a separable weighted
# sum over a 5x5 window of input pixels. The filter support is hard-coded to
# 2.0 and the distance scale to 1.0; the host wrapper selects this kernel only
# when both reciprocal scales are below 1.0.
#
# Arguments:
#   ptr_o, ptr_i            output / input base pointers; both are addressed as
#                           ((n * C + c) * H + h) * W + w, i.e. dense NCHW.
#   N, C                    batch and channel counts (compile-time constants).
#   OH, OW, IH, IW          output and input spatial sizes.
#   reciprocal_scale_h/_w   factor mapping an output coordinate to an input
#                           coordinate along each axis.
#   BLOCK_X, BLOCK_Y        tile width / height, supplied by the heuristics.
@triton.heuristics(
    values={
        "BLOCK_X": heur_m_block_size,
        "BLOCK_Y": heur_n_block_size,
    },
)
@triton.jit
def upsample_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N: tl.constexpr,
    C: tl.constexpr,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    # Grid axis 0 tiles the output width, axis 1 tiles the output height.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Indices past the edge wrap around via the modulo, so those lanes simply
    # recompute (and re-store) an in-range pixel; the store below is unmasked.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    support_w = 2.0
    support_h = 2.0

    # _compute_weights_span
    # Centre of each output pixel in input coordinates.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    # First input index covered by the filter, clamped at 0 and truncated.
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    # Number of taps that lie inside the input, clamped at IW / IH.
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h
    invscale_w = 1.0
    invscale_h = 1.0
    # Coefficient of the piecewise cubic filter evaluated below:
    #   d < 1      : ((a + 2) * d - (a + 3)) * d^2 + 1
    #   1 <= d < 2 : a * (((d - 5) * d + 8) * d - 4)
    #   otherwise  : 0
    a = -0.5
    # Vertical taps 0..4. wyK is the distance from the centre of tap K to the
    # output centre; taps with K >= span_size_h get weight 0.
    wy0 = tl.abs((0 + start_minus_center_h + 0.5) * invscale_h)
    weight_y0 = tl.where(
        0 < span_size_h,
        tl.where(
            wy0 < 1.0,
            ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
            tl.where(wy0 < 2.0, (((wy0 - 5) * wy0 + 8) * wy0 - 4) * a, 0),
        ),
        0,
    )
    wy1 = tl.abs((1 + start_minus_center_h + 0.5) * invscale_h)
    weight_y1 = tl.where(
        1 < span_size_h,
        tl.where(
            wy1 < 1.0,
            ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
            tl.where(wy1 < 2.0, (((wy1 - 5) * wy1 + 8) * wy1 - 4) * a, 0),
        ),
        0,
    )
    wy2 = tl.abs((2 + start_minus_center_h + 0.5) * invscale_h)
    weight_y2 = tl.where(
        2 < span_size_h,
        tl.where(
            wy2 < 1.0,
            ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
            tl.where(wy2 < 2.0, (((wy2 - 5) * wy2 + 8) * wy2 - 4) * a, 0),
        ),
        0,
    )
    wy3 = tl.abs((3 + start_minus_center_h + 0.5) * invscale_h)
    weight_y3 = tl.where(
        3 < span_size_h,
        tl.where(
            wy3 < 1.0,
            ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
            tl.where(wy3 < 2.0, (((wy3 - 5) * wy3 + 8) * wy3 - 4) * a, 0),
        ),
        0,
    )
    wy4 = tl.abs((4 + start_minus_center_h + 0.5) * invscale_h)
    weight_y4 = tl.where(
        4 < span_size_h,
        tl.where(
            wy4 < 1.0,
            ((a + 2) * wy4 - (a + 3)) * wy4 * wy4 + 1,
            tl.where(wy4 < 2.0, (((wy4 - 5) * wy4 + 8) * wy4 - 4) * a, 0),
        ),
        0,
    )
    # Normalise the vertical weights to sum to 1; a zero total is replaced by
    # 1 so the division is always defined.
    weight_y_total = weight_y0 + weight_y1 + weight_y2 + weight_y3 + weight_y4
    weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
    weight_y0 /= weight_y_total
    weight_y1 /= weight_y_total
    weight_y2 /= weight_y_total
    weight_y3 /= weight_y_total
    weight_y4 /= weight_y_total

    # Horizontal taps 0..4, computed the same way as the vertical ones.
    wx0 = tl.abs((0 + start_minus_center_w + 0.5) * invscale_w)
    weight_x0 = tl.where(
        0 < span_size_w,
        tl.where(
            wx0 < 1.0,
            ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
            tl.where(wx0 < 2.0, (((wx0 - 5) * wx0 + 8) * wx0 - 4) * a, 0),
        ),
        0,
    )
    wx1 = tl.abs((1 + start_minus_center_w + 0.5) * invscale_w)
    weight_x1 = tl.where(
        1 < span_size_w,
        tl.where(
            wx1 < 1.0,
            ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
            tl.where(wx1 < 2.0, (((wx1 - 5) * wx1 + 8) * wx1 - 4) * a, 0),
        ),
        0,
    )
    wx2 = tl.abs((2 + start_minus_center_w + 0.5) * invscale_w)
    weight_x2 = tl.where(
        2 < span_size_w,
        tl.where(
            wx2 < 1.0,
            ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
            tl.where(wx2 < 2.0, (((wx2 - 5) * wx2 + 8) * wx2 - 4) * a, 0),
        ),
        0,
    )
    wx3 = tl.abs((3 + start_minus_center_w + 0.5) * invscale_w)
    weight_x3 = tl.where(
        3 < span_size_w,
        tl.where(
            wx3 < 1.0,
            ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
            tl.where(wx3 < 2.0, (((wx3 - 5) * wx3 + 8) * wx3 - 4) * a, 0),
        ),
        0,
    )
    wx4 = tl.abs((4 + start_minus_center_w + 0.5) * invscale_w)
    weight_x4 = tl.where(
        4 < span_size_w,
        tl.where(
            wx4 < 1.0,
            ((a + 2) * wx4 - (a + 3)) * wx4 * wx4 + 1,
            tl.where(wx4 < 2.0, (((wx4 - 5) * wx4 + 8) * wx4 - 4) * a, 0),
        ),
        0,
    )
    weight_x_total = weight_x0 + weight_x1 + weight_x2 + weight_x3 + weight_x4
    weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
    weight_x0 /= weight_x_total
    weight_x1 /= weight_x_total
    weight_x2 /= weight_x_total
    weight_x3 /= weight_x_total
    weight_x4 /= weight_x_total

    # Load masks: a tap is read only if its row / column index is inside the
    # input. Rows vary along tile axis 0, columns along tile axis 1.
    mask_y0 = span_start_h[:, None] + 0 < IH
    mask_y1 = span_start_h[:, None] + 1 < IH
    mask_y2 = span_start_h[:, None] + 2 < IH
    mask_y3 = span_start_h[:, None] + 3 < IH
    mask_y4 = span_start_h[:, None] + 4 < IH
    mask_x0 = span_start_w[None, :] + 0 < IW
    mask_x1 = span_start_w[None, :] + 1 < IW
    mask_x2 = span_start_w[None, :] + 2 < IW
    mask_x3 = span_start_w[None, :] + 3 < IW
    mask_x4 = span_start_w[None, :] + 4 < IW

    # The weights and masks do not depend on n or c, so each program reuses
    # them for every (batch, channel) plane.
    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Flat offset of the top-left tap of each output pixel's window.
            offset_base = (
                (n * C + c) * IH + span_start_h[:, None]
            ) * IW + span_start_w[None, :]

            # dataRC is the tap at window row R, column C; masked-off taps
            # read as 0.
            data00 = tl.load(
                ptr_i + (offset_base + 0 * IW + 0),
                mask=(mask_y0 & mask_x0),
                other=0,
            )
            data01 = tl.load(
                ptr_i + (offset_base + 0 * IW + 1),
                mask=(mask_y0 & mask_x1),
                other=0,
            )
            data02 = tl.load(
                ptr_i + (offset_base + 0 * IW + 2),
                mask=(mask_y0 & mask_x2),
                other=0,
            )
            data03 = tl.load(
                ptr_i + (offset_base + 0 * IW + 3),
                mask=(mask_y0 & mask_x3),
                other=0,
            )
            data04 = tl.load(
                ptr_i + (offset_base + 0 * IW + 4),
                mask=(mask_y0 & mask_x4),
                other=0,
            )

            data10 = tl.load(
                ptr_i + (offset_base + 1 * IW + 0),
                mask=(mask_y1 & mask_x0),
                other=0,
            )
            data11 = tl.load(
                ptr_i + (offset_base + 1 * IW + 1),
                mask=(mask_y1 & mask_x1),
                other=0,
            )
            data12 = tl.load(
                ptr_i + (offset_base + 1 * IW + 2),
                mask=(mask_y1 & mask_x2),
                other=0,
            )
            data13 = tl.load(
                ptr_i + (offset_base + 1 * IW + 3),
                mask=(mask_y1 & mask_x3),
                other=0,
            )
            data14 = tl.load(
                ptr_i + (offset_base + 1 * IW + 4),
                mask=(mask_y1 & mask_x4),
                other=0,
            )

            data20 = tl.load(
                ptr_i + (offset_base + 2 * IW + 0),
                mask=(mask_y2 & mask_x0),
                other=0,
            )
            data21 = tl.load(
                ptr_i + (offset_base + 2 * IW + 1),
                mask=(mask_y2 & mask_x1),
                other=0,
            )
            data22 = tl.load(
                ptr_i + (offset_base + 2 * IW + 2),
                mask=(mask_y2 & mask_x2),
                other=0,
            )
            data23 = tl.load(
                ptr_i + (offset_base + 2 * IW + 3),
                mask=(mask_y2 & mask_x3),
                other=0,
            )
            data24 = tl.load(
                ptr_i + (offset_base + 2 * IW + 4),
                mask=(mask_y2 & mask_x4),
                other=0,
            )

            data30 = tl.load(
                ptr_i + (offset_base + 3 * IW + 0),
                mask=(mask_y3 & mask_x0),
                other=0,
            )
            data31 = tl.load(
                ptr_i + (offset_base + 3 * IW + 1),
                mask=(mask_y3 & mask_x1),
                other=0,
            )
            data32 = tl.load(
                ptr_i + (offset_base + 3 * IW + 2),
                mask=(mask_y3 & mask_x2),
                other=0,
            )
            data33 = tl.load(
                ptr_i + (offset_base + 3 * IW + 3),
                mask=(mask_y3 & mask_x3),
                other=0,
            )
            data34 = tl.load(
                ptr_i + (offset_base + 3 * IW + 4),
                mask=(mask_y3 & mask_x4),
                other=0,
            )

            data40 = tl.load(
                ptr_i + (offset_base + 4 * IW + 0),
                mask=(mask_y4 & mask_x0),
                other=0,
            )
            data41 = tl.load(
                ptr_i + (offset_base + 4 * IW + 1),
                mask=(mask_y4 & mask_x1),
                other=0,
            )
            data42 = tl.load(
                ptr_i + (offset_base + 4 * IW + 2),
                mask=(mask_y4 & mask_x2),
                other=0,
            )
            data43 = tl.load(
                ptr_i + (offset_base + 4 * IW + 3),
                mask=(mask_y4 & mask_x3),
                other=0,
            )
            data44 = tl.load(
                ptr_i + (offset_base + 4 * IW + 4),
                mask=(mask_y4 & mask_x4),
                other=0,
            )

            # Horizontal pass: collapse each window row with the x weights.
            data0 = (
                data00 * weight_x0[None, :]
                + data01 * weight_x1[None, :]
                + data02 * weight_x2[None, :]
                + data03 * weight_x3[None, :]
                + data04 * weight_x4[None, :]
            )
            data1 = (
                data10 * weight_x0[None, :]
                + data11 * weight_x1[None, :]
                + data12 * weight_x2[None, :]
                + data13 * weight_x3[None, :]
                + data14 * weight_x4[None, :]
            )
            data2 = (
                data20 * weight_x0[None, :]
                + data21 * weight_x1[None, :]
                + data22 * weight_x2[None, :]
                + data23 * weight_x3[None, :]
                + data24 * weight_x4[None, :]
            )
            data3 = (
                data30 * weight_x0[None, :]
                + data31 * weight_x1[None, :]
                + data32 * weight_x2[None, :]
                + data33 * weight_x3[None, :]
                + data34 * weight_x4[None, :]
            )
            data4 = (
                data40 * weight_x0[None, :]
                + data41 * weight_x1[None, :]
                + data42 * weight_x2[None, :]
                + data43 * weight_x3[None, :]
                + data44 * weight_x4[None, :]
            )
            # Vertical pass: combine the five row results with the y weights.
            result = (
                data0 * weight_y0[:, None]
                + data1 * weight_y1[:, None]
                + data2 * weight_y2[:, None]
                + data3 * weight_y3[:, None]
                + data4 * weight_y4[:, None]
            )

            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)

400 

401 

402# upsample and downsample 

403# @triton.autotune( 

404# configs=runtime.get_tuned_config("upsample_bicubic2d_aa"), 

405# key=["N", "C", "OH", "OW"], 

406# ) 

# General bicubic kernel whose filter footprint grows with the scale factor.
# Along an axis whose reciprocal scale is >= 1.0 the support is widened to
# 2 * reciprocal_scale and tap distances are multiplied by 1 / reciprocal_scale;
# otherwise the support is 2.0 and distances are unscaled. The host wrapper
# selects this kernel when at least one reciprocal scale is >= 1.0.
#
# Arguments:
#   ptr_o, ptr_i            output / input base pointers; both are addressed as
#                           ((n * C + c) * H + h) * W + w, i.e. dense NCHW.
#   N, C                    batch and channel counts (runtime values here).
#   OH, OW, IH, IW          output and input spatial sizes.
#   reciprocal_scale_h/_w   factor mapping an output coordinate to an input
#                           coordinate along each axis.
#   BLOCK_X, BLOCK_Y        tile width / height, supplied by the heuristics.
@triton.heuristics(
    values={
        "BLOCK_X": heur_m_block_size,
        "BLOCK_Y": heur_n_block_size,
    },
)
@triton.jit
def general_interpolate_bicubic2d_aa_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_X: tl.constexpr,
    BLOCK_Y: tl.constexpr,
):
    # Grid axis 0 tiles the output width, axis 1 tiles the output height.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # Indices past the edge wrap around via the modulo, so those lanes simply
    # recompute (and re-store) an in-range pixel; the store below is unmasked.
    ow = (pid_x * BLOCK_X + tl.arange(0, BLOCK_X)) % OW
    oh = (pid_y * BLOCK_Y + tl.arange(0, BLOCK_Y)) % OH

    # Filter half-width per axis, in input pixels.
    if reciprocal_scale_w >= 1.0:
        support_w = 2 * reciprocal_scale_w
    else:
        support_w = 2.0
    if reciprocal_scale_h >= 1.0:
        support_h = 2 * reciprocal_scale_h
    else:
        support_h = 2.0

    # Tap-loop trip counts per axis, derived from the support.
    interpolate_w = (support_w + 0.5).to(tl.int32) * 2 + 1
    interpolate_h = (support_h + 0.5).to(tl.int32) * 2 + 1

    # _compute_weights_span
    # Centre of each output pixel in input coordinates.
    center_w = (ow + 0.5) * reciprocal_scale_w
    center_h = (oh + 0.5) * reciprocal_scale_h
    # First input index covered by the filter, clamped at 0 and truncated.
    span_start_w = tl.maximum(center_w - support_w + 0.5, 0).to(tl.int32)
    span_start_h = tl.maximum(center_h - support_h + 0.5, 0).to(tl.int32)
    # Number of taps that lie inside the input, clamped at IW / IH.
    span_size_w = (tl.minimum(center_w + support_w + 0.5, IW) - span_start_w).to(
        tl.int32
    )
    span_size_h = (tl.minimum(center_h + support_h + 0.5, IH) - span_start_h).to(
        tl.int32
    )

    # Factor applied to tap distances before the cubic filter is evaluated.
    if reciprocal_scale_w >= 1.0:
        invscale_w = 1.0 / reciprocal_scale_w
    else:
        invscale_w = 1.0
    if reciprocal_scale_h >= 1.0:
        invscale_h = 1.0 / reciprocal_scale_h
    else:
        invscale_h = 1.0
    start_minus_center_w = span_start_w - center_w
    start_minus_center_h = span_start_h - center_h

    # Coefficient of the piecewise cubic filter evaluated below:
    #   d < 1      : ((a + 2) * d - (a + 3)) * d^2 + 1
    #   1 <= d < 2 : a * (((d - 5) * d + 8) * d - 4)
    #   otherwise  : 0
    a = -0.5
    for n in range(0, N, 1):
        for c in range(0, C, 1):
            # Flat offset of the first tap of each output pixel's window.
            offset_base = ((n * C + c) * IH + span_start_h[:, None]) * IW + span_start_w
            weight_y_total = tl.zeros((BLOCK_Y,), dtype=tl.float32)
            result = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
            for y in range(0, interpolate_h, 1):
                # Weight of window row y; rows with y >= span_size_h get 0.
                wy = tl.abs((y + start_minus_center_h + 0.5) * invscale_h)
                weight_y = tl.where(
                    y < span_size_h,
                    tl.where(
                        wy < 1.0,
                        ((a + 2) * wy - (a + 3)) * wy * wy + 1,
                        tl.where(wy < 2.0, (((wy - 5) * wy + 8) * wy - 4) * a, 0),
                    ),
                    0,
                )
                weight_y_total += weight_y
                # The horizontal weights and their sum are rebuilt for every
                # row y.
                weight_x_total = tl.zeros((BLOCK_X,), dtype=tl.float32)
                buffer = tl.zeros((BLOCK_Y, BLOCK_X), dtype=tl.float32)
                for x in range(0, interpolate_w, 1):
                    wx = tl.abs((x + start_minus_center_w + 0.5) * invscale_w)
                    weight_x = tl.where(
                        x < span_size_w,
                        tl.where(
                            wx < 1.0,
                            ((a + 2) * wx - (a + 3)) * wx * wx + 1,
                            tl.where(wx < 2.0, (((wx - 5) * wx + 8) * wx - 4) * a, 0),
                        ),
                        0,
                    )
                    weight_x_total += weight_x
                    # Taps whose row or column falls outside the input read
                    # as 0.
                    data = tl.load(
                        ptr_i + (offset_base + y * IW + x),
                        mask=(span_start_h[:, None] + y < IH)
                        & (span_start_w[None, :] + x < IW),
                        other=0,
                    )
                    buffer += data * weight_x[None, :]
                # Normalise the row by the x-weight sum (a zero sum is replaced
                # by 1), then accumulate it with the unnormalised y weight.
                weight_x_total = tl.where(weight_x_total != 0, weight_x_total, 1)
                result += buffer / weight_x_total[None, :] * weight_y[:, None]
            # Apply the y normalisation once, after all rows are accumulated.
            weight_y_total = tl.where(weight_y_total != 0, weight_y_total, 1)
            result /= weight_y_total[:, None]
            offset_o = ((n * C + c) * OH + oh[:, None]) * OW + ow[None, :]
            tl.store(ptr_o + offset_o, result)

513 

514 

def bicubic_reciprocal_scale(src_size, dst_size, align_corners, scale):
    """Return the factor that maps an output coordinate to an input coordinate.

    Args:
        src_size: Input extent along the axis.
        dst_size: Output extent along the axis.
        align_corners: When true, the factor is (src_size - 1) / (dst_size - 1),
            or 0 if ``dst_size`` is not greater than 1; ``scale`` is ignored.
        scale: Optional explicit scale factor. Used (as ``1.0 / scale``) only
            when ``align_corners`` is false and the value is positive;
            otherwise ``src_size / dst_size`` is returned.
    """
    if not align_corners:
        has_explicit_scale = scale is not None and scale > 0
        return 1.0 / scale if has_explicit_scale else src_size / dst_size

    if dst_size > 1:
        return (src_size - 1) / (dst_size - 1)
    return 0

526 

527 

528# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L12547 

def _upsample_bicubic2d_aa(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    align_corners: bool = False,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Resize a 4-D tensor with the bicubic Triton kernels in this module.

    Args:
        input: Tensor of shape (N, C, IH, IW) on the backend device.
        output_size: Target spatial size ``(OH, OW)``.
        align_corners: Forwarded to ``bicubic_reciprocal_scale`` for both axes.
        scales_h: Optional explicit scale factor for the height axis.
        scales_w: Optional explicit scale factor for the width axis.

    Returns:
        A new tensor of shape (N, C, OH, OW) with the dtype and device of
        ``input``.

    Raises:
        AssertionError: If ``input`` is on a different device type, is not
            4-D, or ``output_size`` does not have exactly two entries.
    """
    import os

    logger.debug("GEMS UPSAMPLE BICUBIC2D AA")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    N, C, IH, IW = input.shape

    reciprocal_scale_h = bicubic_reciprocal_scale(IH, OH, align_corners, scales_h)
    reciprocal_scale_w = bicubic_reciprocal_scale(IW, OW, align_corners, scales_w)

    # Both kernels compute flat offsets as ((n * C + c) * H + h) * W + w and
    # never look at strides, so the input must be dense NCHW. This is a no-op
    # for tensors that are already contiguous.
    input = input.contiguous()

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)

    def grid(META):
        return (
            triton.cdiv(OW, META["BLOCK_X"]),
            triton.cdiv(OH, META["BLOCK_Y"]),
        )

    # The variable-footprint kernel is needed as soon as either axis has a
    # reciprocal scale >= 1.0; otherwise the fixed 5x5 kernel suffices.
    kernel = (
        general_interpolate_bicubic2d_aa_kernel
        if (reciprocal_scale_w >= 1.0) or (reciprocal_scale_h >= 1.0)
        else upsample_bicubic2d_aa_kernel
    )

    # NOTE(review): these flags presumably switch the Triton-XPU compiler's
    # handling of masked-load `other` values and store masks -- confirm
    # against the backend documentation.
    sim_flags = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    previous_values = {name: os.environ.get(name) for name in sim_flags}
    for name in sim_flags:
        os.environ[name] = "1"
    try:
        with torch_device_fn.device(input.device):
            kernel[grid](
                output,
                input,
                N,
                C,
                OH,
                OW,
                IH,
                IW,
                reciprocal_scale_h,
                reciprocal_scale_w,
            )
    finally:
        # Restore the environment even if the launch raised, and put back any
        # value the caller had set rather than unconditionally deleting it.
        for name, previous in previous_values.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    return output