Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/unique.py: 0%

524 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import os 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

11 

12@libentry() 

13@triton.jit 

14def simple_unique_flat_kernel( 

15 sorted_data_ptr: tl.tensor, 

16 sorted_indices_ptr: tl.tensor, # in 

17 data_out_ptr: tl.tensor, 

18 inverse_indices_ptr: tl.tensor, 

19 idx_ptr: tl.tensor, 

20 unique_size_ptr: tl.tensor, # out 

21 return_inverse: tl.constexpr, 

22 return_counts: tl.constexpr, 

23 num_tasks: int, 

24 tile_size: tl.constexpr, 

25): 

26 i0 = tl.arange(0, tile_size) 

27 mask = i0 < num_tasks 

28 

29 # load 

30 a = tl.load(sorted_data_ptr + i0, mask=mask) 

31 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

32 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

33 

34 # ne & cumsum 

35 ne_result = tl.where(i0 > 0, a != b, 0) 

36 cumsum = tl.cumsum(ne_result) 

37 

38 # unique_size 

39 unique_size_mask = i0 == tile_size - 1 

40 unique_off = tl.where(unique_size_mask, tl.zeros_like(i0), -1) 

41 tl.store(unique_size_ptr + unique_off, cumsum, mask=unique_size_mask) 

42 

43 # data_out: scatter_(to=cumsum, sorted_data) 

44 data_out_off = tl.where(mask, cumsum, -1) 

45 tl.store(data_out_ptr + data_out_off, a, mask=mask) 

46 

47 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

48 if return_inverse: 

49 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

50 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

51 

52 # idx 

53 if return_counts: 

54 idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask 

55 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

56 

57 

58@triton.jit 

59def output_counts_flat_impl( 

60 global_pid, 

61 idx_ptr: tl.tensor, 

62 origin_num_tasks: int, # in 

63 counts_ptr: tl.tensor, # out 

64 num_tasks: int, 

65 tile_size: tl.constexpr, 

66): 

67 r = tl.arange(0, tile_size) 

68 

69 # load idx 

70 i0 = global_pid * tile_size + r 

71 mask = i0 < num_tasks 

72 idx = tl.load(idx_ptr + i0, mask=mask) 

73 

74 # load idx_next 

75 i0_next = i0 + 1 

76 next_mask = i0_next < num_tasks 

77 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

78 

79 # diff 

80 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

81 

82 # store counts 

83 tl.store(counts_ptr + i0, counts, mask=mask) 

84 

85 

86@libentry() 

87@triton.jit 

88def output_counts_flat_kernel( 

89 idx_ptr: tl.tensor, 

90 origin_num_tasks: int, # in 

91 counts_ptr: tl.tensor, # out 

92 num_tasks: int, 

93 tiles_per_cta: int, 

94 tile_size: tl.constexpr, 

95): 

96 pid = tle.program_id(0) 

97 ctas_num = tle.num_programs(0) 

98 # grid-stride-loop style kernel 

99 for j in range(0, tiles_per_cta): 

100 global_pid = pid + j * ctas_num 

101 output_counts_flat_impl( 

102 global_pid, 

103 idx_ptr, 

104 origin_num_tasks, # in 

105 counts_ptr, # out 

106 num_tasks, 

107 tile_size, 

108 ) 

109 

110 

111@triton.jit 

112def quick_output_flat_impl( 

113 global_pid, 

114 sorted_data_ptr: tl.tensor, 

115 idx_ptr: tl.tensor, 

116 origin_num_tasks: int, # in 

117 data_out_ptr: tl.tensor, 

118 counts_ptr: tl.tensor, # out 

119 num_tasks: int, 

120 tile_size: tl.constexpr, 

121): 

122 r = tl.arange(0, tile_size) 

123 

124 # load idx 

125 i0 = global_pid * tile_size + r 

126 mask = i0 < num_tasks 

127 idx = tl.load(idx_ptr + i0, mask=mask) 

128 

129 # load idx_next 

130 i0_next = i0 + 1 

131 next_mask = i0_next < num_tasks 

132 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

133 

134 # diff 

135 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

136 

137 # store counts 

138 tl.store(counts_ptr + i0, counts, mask=mask) 

139 

140 # data_out: gather(sorted_data, from=idx) 

141 sorted_data = tl.load(sorted_data_ptr + idx, mask=mask) 

142 tl.store(data_out_ptr + i0, sorted_data, mask=mask) 

143 

144 

145@libentry() 

146@triton.jit 

147def quick_output_flat_kernel( 

148 sorted_data_ptr: tl.tensor, 

149 idx_ptr: tl.tensor, 

150 origin_num_tasks: int, # in 

151 data_out_ptr: tl.tensor, 

152 counts_ptr: tl.tensor, # out 

153 num_tasks: int, 

154 tiles_per_cta: int, 

155 tile_size: tl.constexpr, 

156): 

157 pid = tle.program_id(0) 

158 ctas_num = tle.num_programs(0) 

159 # grid-stride-loop style kernel 

160 for j in range(0, tiles_per_cta): 

161 global_pid = pid + j * ctas_num 

162 quick_output_flat_impl( 

163 global_pid, 

164 sorted_data_ptr, 

165 idx_ptr, 

166 origin_num_tasks, # in 

167 data_out_ptr, 

168 counts_ptr, # out 

169 num_tasks, 

170 tile_size, 

171 ) 

172 

173 

174@triton.jit 

175def local_quick_unique_flat_impl( 

176 global_pid, 

177 sorted_data_ptr: tl.tensor, # in 

178 local_unique_ptr: tl.tensor, 

179 origin_idx_ptr: tl.tensor, 

180 tile_sum_ptr: tl.tensor, # out 

181 global_ctas_num: int, 

182 num_tasks: int, 

183 tile_size: tl.constexpr, 

184 return_counts: tl.constexpr, 

185): 

186 offset = global_pid * tile_size 

187 r = tl.arange(0, tile_size) 

188 i0 = offset + r 

189 mask = i0 < num_tasks 

190 

191 # load 

192 a = tl.load(sorted_data_ptr + i0, mask=mask) 

193 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

194 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

195 

196 # ne & cumsum 

197 ne_result = tl.where(i0 > 0, a != b, 0) 

198 cumsum = tl.cumsum(ne_result) 

199 

200 # local_id or local_unique 

201 local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0) 

202 local_unique_mask = (local_unique_offset >= 0) & mask 

203 if return_counts: 

204 # origin_idx: scatter_(to=cumsum, i0) 

205 origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask 

206 lu_store_offset = offset + local_unique_offset 

207 lu_store_offset = tl.where(origin_idx_mask, lu_store_offset, -1) 

208 tl.store( 

209 origin_idx_ptr + lu_store_offset, 

210 i0, 

211 mask=origin_idx_mask, 

212 ) 

213 else: 

214 # local_unique: scatter_(to=cumsum, sorted_data) 

215 lu_store_offset = offset + local_unique_offset 

216 lu_store_offset = tl.where(local_unique_mask, lu_store_offset, -1) 

217 tl.store(local_unique_ptr + lu_store_offset, a, mask=local_unique_mask) 

218 

219 # tile_sum 

220 tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num) 

221 tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum) 

222 tile_sum_store_offset = global_pid + tl.zeros_like(r) 

223 tile_sum_store_offset = tl.where(tile_sum_mask, tile_sum_store_offset, -1) 

224 tl.store(tile_sum_ptr + tile_sum_store_offset, tile_sum, mask=tile_sum_mask) 

225 

226 

227@libentry() 

228@triton.jit 

229def local_quick_unique_flat_kernel( 

230 sorted_data_ptr: tl.tensor, # in 

231 local_unique_ptr: tl.tensor, 

232 origin_idx_ptr: tl.tensor, 

233 tile_sum_ptr: tl.tensor, # out 

234 global_ctas_num: int, 

235 num_tasks: int, 

236 tiles_per_cta: int, 

237 tile_size: tl.constexpr, 

238 return_counts: tl.constexpr, 

239): 

240 pid = tle.program_id(0) 

241 ctas_num = tle.num_programs(0) 

242 # grid-stride-loop style kernel 

243 for j in range(0, tiles_per_cta): 

244 global_pid = pid + j * ctas_num 

245 local_quick_unique_flat_impl( 

246 global_pid, 

247 sorted_data_ptr, # in 

248 local_unique_ptr, 

249 origin_idx_ptr, 

250 tile_sum_ptr, # out 

251 global_ctas_num, 

252 num_tasks, 

253 tile_size, 

254 return_counts, 

255 ) 

256 

257 

258@triton.jit 

259def global_quick_unique_flat_impl( 

260 global_pid, 

261 total, 

262 local_unique_ptr: tl.tensor, 

263 origin_idx_ptr: tl.tensor, 

264 tile_sum_ptr: tl.tensor, # in 

265 data_out_ptr: tl.tensor, 

266 idx_ptr: tl.tensor, # out 

267 ctas_num: tl.constexpr, 

268 global_ctas_num: tl.constexpr, 

269 next_power_global_ctas_num: tl.constexpr, 

270 num_tasks: tl.constexpr, 

271 tile_size: tl.constexpr, 

272 return_counts: tl.constexpr, 

273): 

274 r = tl.arange(0, tile_size) 

275 i0 = global_pid * tile_size + r 

276 mask = i0 < num_tasks 

277 

278 # load tile_sum 

279 p = tl.arange(0, next_power_global_ctas_num) 

280 pre_tile_sum_mask = ( 

281 (p >= global_pid - ctas_num) 

282 & (p < global_pid) 

283 & (p >= 0) 

284 & (p < global_ctas_num) 

285 ) 

286 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

287 cur_tile_sum_mask = global_pid < global_ctas_num 

288 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) 

289 

290 # total 

291 total += tl.sum(pre_tile_sum) 

292 if global_pid == global_ctas_num - 1: 

293 last_tile_sum_mask = p == global_pid 

294 tile_offset = tl.where(last_tile_sum_mask, p, -1) 

295 tl.store( 

296 tile_sum_ptr + tile_offset, total + cur_tile_sum, mask=last_tile_sum_mask 

297 ) 

298 

299 # idx or data_out 

300 tile_mask = r < cur_tile_sum 

301 out_offset = total + r 

302 if return_counts: 

303 # move origin_idx to idx_ptr 

304 origin_idx = tl.load(origin_idx_ptr + i0, mask=mask) 

305 idx_offset = tl.where(tile_mask, out_offset, -1) 

306 tl.store(idx_ptr + idx_offset, origin_idx, mask=tile_mask) 

307 else: 

308 # move local_unique to data_out_ptr 

309 local_unique = tl.load(local_unique_ptr + i0, mask=mask) 

310 data_out_offset = tl.where(tile_mask, out_offset, -1) 

311 tl.store(data_out_ptr + data_out_offset, local_unique, mask=tile_mask) 

312 

313 return total 

314 

315 

316@libentry() 

317@triton.jit 

318def global_quick_unique_flat_kernel( 

319 local_unique_ptr: tl.tensor, 

320 origin_idx_ptr: tl.tensor, 

321 tile_sum_ptr: tl.tensor, # in 

322 data_out_ptr: tl.tensor, 

323 idx_ptr: tl.tensor, # out 

324 ctas_num: tl.constexpr, 

325 global_ctas_num: tl.constexpr, 

326 next_power_global_ctas_num: tl.constexpr, 

327 num_tasks: tl.constexpr, 

328 tiles_per_cta: tl.constexpr, 

329 tile_size: tl.constexpr, 

330 one_tile_per_cta: tl.constexpr, 

331 return_counts: tl.constexpr, 

332): 

333 pid = tle.program_id(0) 

334 ctas_num = tle.num_programs(0) 

335 if one_tile_per_cta: # monolitic kernel style 

336 global_quick_unique_flat_impl( 

337 pid, 

338 0, 

339 local_unique_ptr, 

340 origin_idx_ptr, 

341 tile_sum_ptr, # in 

342 data_out_ptr, 

343 idx_ptr, # out 

344 ctas_num, 

345 global_ctas_num, 

346 next_power_global_ctas_num, 

347 num_tasks, 

348 tile_size, 

349 return_counts, 

350 ) 

351 else: # grid-stride-loop style kernel 

352 total = tl.zeros([1], dtype=tl.int64) 

353 for j in range(0, tiles_per_cta): 

354 global_pid = pid + j * ctas_num 

355 total = global_quick_unique_flat_impl( 

356 global_pid, 

357 total, 

358 local_unique_ptr, 

359 origin_idx_ptr, 

360 tile_sum_ptr, # in 

361 data_out_ptr, 

362 idx_ptr, # out 

363 ctas_num, 

364 global_ctas_num, 

365 next_power_global_ctas_num, 

366 num_tasks, 

367 tile_size, 

368 return_counts, 

369 ) 

370 

371 

372@triton.jit 

373def global_quick_unique_flat_impl_stage_1( 

374 global_pid, 

375 total, 

376 local_unique_ptr: tl.tensor, 

377 origin_idx_ptr: tl.tensor, 

378 tile_sum_ptr: tl.tensor, # in 

379 data_out_ptr: tl.tensor, 

380 idx_ptr: tl.tensor, # out 

381 ctas_num: tl.constexpr, 

382 global_ctas_num: tl.constexpr, 

383 next_power_global_ctas_num: tl.constexpr, 

384 num_tasks: tl.constexpr, 

385 tile_size: tl.constexpr, 

386 return_counts: tl.constexpr, 

387): 

388 # r = tl.arange(0, tile_size) 

389 # i0 = global_pid * tile_size + r 

390 # mask = i0 < num_tasks 

391 

392 # load tile_sum 

393 p = tl.arange(0, next_power_global_ctas_num) 

394 pre_tile_sum_mask = ( 

395 (p >= global_pid - ctas_num) 

396 & (p < global_pid) 

397 & (p >= 0) 

398 & (p < global_ctas_num) 

399 ) 

400 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

401 cur_tile_sum_mask = global_pid < global_ctas_num 

402 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) 

403 

404 # total 

405 total += tl.sum(pre_tile_sum) 

406 if global_pid == global_ctas_num - 1: 

407 last_tile_sum_mask = p == global_pid 

408 tile_offset = tl.where(last_tile_sum_mask, p, -1) 

409 tl.store( 

410 tile_sum_ptr + tile_offset, total + cur_tile_sum, mask=last_tile_sum_mask 

411 ) 

412 

413 return total 

414 

415 

416@libentry() 

417@triton.jit 

418def global_quick_unique_flat_kernel_stage_1( 

419 local_unique_ptr: tl.tensor, 

420 origin_idx_ptr: tl.tensor, 

421 tile_sum_ptr: tl.tensor, # in 

422 data_out_ptr: tl.tensor, 

423 idx_ptr: tl.tensor, # out 

424 ctas_num: tl.constexpr, 

425 global_ctas_num: tl.constexpr, 

426 next_power_global_ctas_num: tl.constexpr, 

427 num_tasks: tl.constexpr, 

428 tiles_per_cta: tl.constexpr, 

429 tile_size: tl.constexpr, 

430 one_tile_per_cta: tl.constexpr, 

431 return_counts: tl.constexpr, 

432): 

433 pid = tle.program_id(0) 

434 ctas_num = tle.num_programs(0) 

435 if one_tile_per_cta: # monolitic kernel style 

436 global_quick_unique_flat_impl_stage_1( 

437 pid, 

438 0, 

439 local_unique_ptr, 

440 origin_idx_ptr, 

441 tile_sum_ptr, # in 

442 data_out_ptr, 

443 idx_ptr, # out 

444 ctas_num, 

445 global_ctas_num, 

446 next_power_global_ctas_num, 

447 num_tasks, 

448 tile_size, 

449 return_counts, 

450 ) 

451 else: # grid-stride-loop style kernel 

452 total = tl.zeros([1], dtype=tl.int64) 

453 for j in range(0, tiles_per_cta): 

454 global_pid = pid + j * ctas_num 

455 total = global_quick_unique_flat_impl_stage_1( 

456 global_pid, 

457 total, 

458 local_unique_ptr, 

459 origin_idx_ptr, 

460 tile_sum_ptr, # in 

461 data_out_ptr, 

462 idx_ptr, # out 

463 ctas_num, 

464 global_ctas_num, 

465 next_power_global_ctas_num, 

466 num_tasks, 

467 tile_size, 

468 return_counts, 

469 ) 

470 

471 

472@triton.jit 

473def global_quick_unique_flat_impl_stage_2( 

474 global_pid, 

475 total, 

476 local_unique_ptr: tl.tensor, 

477 origin_idx_ptr: tl.tensor, 

478 tile_sum_ptr: tl.tensor, # in 

479 data_out_ptr: tl.tensor, 

480 idx_ptr: tl.tensor, # out 

481 total_in_ptr, 

482 ctas_num: tl.constexpr, 

483 global_ctas_num: tl.constexpr, 

484 next_power_global_ctas_num: tl.constexpr, 

485 num_tasks: tl.constexpr, 

486 tile_size: tl.constexpr, 

487 return_counts: tl.constexpr, 

488): 

489 r = tl.arange(0, tile_size) 

490 i0 = global_pid * tile_size + r 

491 mask = i0 < num_tasks 

492 

493 # load tile_sum 

494 # p = tl.arange(0, next_power_global_ctas_num) 

495 # pre_tile_sum_mask = ( 

496 # (p >= global_pid - ctas_num) 

497 # & (p < global_pid) 

498 # & (p >= 0) 

499 # & (p < global_ctas_num) 

500 # ) 

501 # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

502 cur_tile_sum_mask = global_pid < global_ctas_num 

503 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) 

504 

505 # total 

506 total_in_mask = global_pid < global_ctas_num 

507 total = tl.load(total_in_ptr + global_pid, mask=total_in_mask) 

508 # tl.device_print("total", total) 

509 

510 # idx or data_out 

511 tile_mask = r < cur_tile_sum 

512 out_offset = total + r 

513 if return_counts: 

514 # move origin_idx to idx_ptr 

515 origin_idx = tl.load(origin_idx_ptr + i0, mask=mask) 

516 idx_offset = tl.where(tile_mask, out_offset, -1) 

517 tl.store(idx_ptr + idx_offset, origin_idx, mask=tile_mask) 

518 else: 

519 # move local_unique to data_out_ptr 

520 local_unique = tl.load(local_unique_ptr + i0, mask=mask) 

521 data_out_offset = tl.where(tile_mask, out_offset, -1) 

522 tl.store(data_out_ptr + data_out_offset, local_unique, mask=tile_mask) 

523 

524 return total 

525 

526 

527@libentry() 

528@triton.jit 

529def global_quick_unique_flat_kernel_stage_2( 

530 local_unique_ptr: tl.tensor, 

531 origin_idx_ptr: tl.tensor, 

532 tile_sum_ptr: tl.tensor, # in 

533 data_out_ptr: tl.tensor, 

534 idx_ptr: tl.tensor, # out 

535 total_in_ptr, 

536 ctas_num: tl.constexpr, 

537 global_ctas_num: tl.constexpr, 

538 next_power_global_ctas_num: tl.constexpr, 

539 num_tasks: tl.constexpr, 

540 tiles_per_cta: tl.constexpr, 

541 tile_size: tl.constexpr, 

542 one_tile_per_cta: tl.constexpr, 

543 return_counts: tl.constexpr, 

544): 

545 pid = tle.program_id(0) 

546 ctas_num = tle.num_programs(0) 

547 if one_tile_per_cta: # monolitic kernel style 

548 global_quick_unique_flat_impl_stage_2( 

549 pid, 

550 0, 

551 local_unique_ptr, 

552 origin_idx_ptr, 

553 tile_sum_ptr, # in 

554 data_out_ptr, 

555 idx_ptr, # out 

556 total_in_ptr, 

557 ctas_num, 

558 global_ctas_num, 

559 next_power_global_ctas_num, 

560 num_tasks, 

561 tile_size, 

562 return_counts, 

563 ) 

564 else: # grid-stride-loop style kernel 

565 total = tl.zeros([1], dtype=tl.int64) 

566 for j in range(0, tiles_per_cta): 

567 global_pid = pid + j * ctas_num 

568 total = global_quick_unique_flat_impl_stage_2( 

569 global_pid, 

570 total, 

571 local_unique_ptr, 

572 origin_idx_ptr, 

573 tile_sum_ptr, # in 

574 data_out_ptr, 

575 idx_ptr, # out 

576 total_in_ptr, 

577 ctas_num, 

578 global_ctas_num, 

579 next_power_global_ctas_num, 

580 num_tasks, 

581 tile_size, 

582 return_counts, 

583 ) 

584 

585 

586def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): 

587 num_tasks = sorted_data.numel() 

588 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

589 tile_size = min(8192, next_power_num_tasks) 

590 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

591 # if global_ctas_num <= 8192: 

592 # tile_size = max( 

593 # 32, min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks) 

594 # ) 

595 # global_ctas_num = triton.cdiv(num_tasks, tile_size) 

596 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

597 ctas_num = global_ctas_num # if global_ctas_num < 65536 else 2048 

598 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

599 num_warps = 8 if tiles_per_cta == 1 else 32 

600 grid = (ctas_num, 1, 1) 

601 # print(f"ctas_num = {ctas_num}") 

602 # print(f"tile_size = {tile_size}") 

603 # print(f"global_ctas_num = {global_ctas_num}") 

604 # print(f"tiles_per_cta = {tiles_per_cta}") 

605 

606 # allocate tensor 

607 if return_counts: 

608 local_unique = None 

609 origin_idx = torch.empty_like(sorted_data, dtype=torch.int64) 

610 idx = torch.empty_like(origin_idx) 

611 else: 

612 local_unique = torch.empty_like(sorted_data) 

613 origin_idx = None 

614 idx = None 

615 counts = None 

616 tile_sum = torch.empty( 

617 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

618 ) 

619 data_out = None 

620 if not return_counts: 

621 data_out = torch.empty_like(sorted_data) 

622 assert tiles_per_cta == 1 

623 # launch kernel 

624 with torch_device_fn.device(sorted_data.device.index): 

625 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

626 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

627 local_quick_unique_flat_kernel[grid]( 

628 sorted_data, # in 

629 local_unique, 

630 origin_idx, 

631 tile_sum, # out 

632 global_ctas_num, 

633 num_tasks, 

634 tiles_per_cta=tiles_per_cta, 

635 tile_size=tile_size, 

636 return_counts=return_counts, 

637 num_warps=num_warps, 

638 ) 

639 if "TRITONXPU_OTHER_SIM" in os.environ: 

640 del os.environ["TRITONXPU_OTHER_SIM"] 

641 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

642 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

643 

644 if num_tasks < 2**26: 

645 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

646 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

647 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

648 global_quick_unique_flat_kernel[grid]( 

649 local_unique, 

650 origin_idx, 

651 tile_sum, # in 

652 data_out, 

653 idx, # out 

654 ctas_num, 

655 global_ctas_num, 

656 next_power_global_ctas_num, 

657 num_tasks, 

658 tiles_per_cta=tiles_per_cta, 

659 tile_size=tile_size, 

660 one_tile_per_cta=tiles_per_cta == 1, 

661 return_counts=return_counts, 

662 num_warps=num_warps, 

663 isCloseVectorization=True, 

664 # buffer_size_limit=128, 

665 ) 

666 if "TRITONXPU_OTHER_SIM" in os.environ: 

667 del os.environ["TRITONXPU_OTHER_SIM"] 

668 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

669 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

670 if "TRITONXPU_INTERLEAVE" in os.environ: 

671 del os.environ["TRITONXPU_INTERLEAVE"] 

672 else: 

673 # print(f'tile_sum.shape = {tile_sum.shape}') 

674 # print(f'tile_sum.cpu() = {tile_sum.cpu()}') 

675 total_in = torch.cumsum(tile_sum, dim=0) 

676 total_in = torch.roll(total_in, shifts=1) 

677 total_in[0] = 0 

678 # print(f'in total_in.cpu() = {total_in.cpu()}') 

679 

680 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

681 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

682 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

683 global_quick_unique_flat_kernel_stage_1[grid]( 

684 local_unique, 

685 origin_idx, 

686 tile_sum, # in 

687 data_out, 

688 idx, # out 

689 ctas_num, 

690 global_ctas_num, 

691 next_power_global_ctas_num, 

692 num_tasks, 

693 tiles_per_cta=tiles_per_cta, 

694 tile_size=tile_size, 

695 one_tile_per_cta=tiles_per_cta == 1, 

696 return_counts=return_counts, 

697 num_warps=num_warps, 

698 isCloseVectorization=True, 

699 buffer_size_limit=128, 

700 ) 

701 if "TRITONXPU_OTHER_SIM" in os.environ: 

702 del os.environ["TRITONXPU_OTHER_SIM"] 

703 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

704 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

705 if "TRITONXPU_INTERLEAVE" in os.environ: 

706 del os.environ["TRITONXPU_INTERLEAVE"] 

707 

708 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

709 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

710 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

711 global_quick_unique_flat_kernel_stage_2[grid]( 

712 local_unique, 

713 origin_idx, 

714 tile_sum, # in 

715 data_out, 

716 idx, # out 

717 total_in, 

718 ctas_num, 

719 global_ctas_num, 

720 next_power_global_ctas_num, 

721 num_tasks, 

722 tiles_per_cta=tiles_per_cta, 

723 tile_size=tile_size, 

724 one_tile_per_cta=tiles_per_cta == 1, 

725 return_counts=return_counts, 

726 num_warps=num_warps, 

727 isCloseVectorization=True, 

728 ) 

729 if "TRITONXPU_OTHER_SIM" in os.environ: 

730 del os.environ["TRITONXPU_OTHER_SIM"] 

731 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

732 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

733 if "TRITONXPU_INTERLEAVE" in os.environ: 

734 del os.environ["TRITONXPU_INTERLEAVE"] 

735 

736 out_size = tile_sum[-1].item() 

737 if return_counts: 

738 data_out = torch.empty( 

739 (out_size,), dtype=sorted_data.dtype, device=sorted_data.device 

740 ) 

741 idx = idx[:out_size] 

742 counts = origin_idx[:out_size] 

743 quick_output_flat_kernel[grid]( 

744 sorted_data, 

745 idx, 

746 num_tasks, # in 

747 data_out, 

748 counts, # out 

749 out_size, 

750 tiles_per_cta, 

751 tile_size, 

752 num_warps=num_warps, 

753 isCloseUnrollControl=True 

754 if sorted_data.dtype == torch.int16 

755 else False, 

756 ) 

757 

758 if return_counts: 

759 return data_out, None, counts 

760 else: 

761 return data_out[:out_size], None, None 

762 

763 

764@triton.jit 

765def local_ne_flat_impl( 

766 global_pid, 

767 sorted_data_ptr: tl.tensor, # in 

768 ne_result_ptr: tl.tensor, 

769 tile_sum_ptr: tl.tensor, # out 

770 global_ctas_num: int, 

771 num_tasks: int, 

772 tile_size: tl.constexpr, 

773): 

774 r = tl.arange(0, tile_size) 

775 i0 = global_pid * tile_size + r 

776 mask = i0 < num_tasks 

777 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

778 

779 # load 

780 a = tl.load(sorted_data_ptr + i0, mask=mask) 

781 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

782 

783 # compute 

784 ne_result = tl.where(i0 > 0, a != b, 0) 

785 

786 # store ne_result 

787 tl.store(ne_result_ptr + i0, ne_result, mask=mask) 

788 

789 # store tile_sum 

790 tile_sum = tl.sum(ne_result) 

791 tile_sum_mask = global_pid < global_ctas_num 

792 tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask) 

793 

794 

795@libentry() 

796@triton.jit 

797def local_ne_flat_kernel( 

798 sorted_data_ptr: tl.tensor, # in 

799 ne_result_ptr: tl.tensor, 

800 tile_sum_ptr: tl.tensor, # out 

801 global_ctas_num: int, 

802 num_tasks: int, 

803 tiles_per_cta: int, 

804 tile_size: tl.constexpr, 

805): 

806 pid = tle.program_id(0) 

807 ctas_num = tle.num_programs(0) 

808 # grid-stride-loop style kernel 

809 for j in range(0, tiles_per_cta): 

810 global_pid = pid + j * ctas_num 

811 local_ne_flat_impl( 

812 global_pid, 

813 sorted_data_ptr, # in 

814 ne_result_ptr, 

815 tile_sum_ptr, # out 

816 global_ctas_num, 

817 num_tasks, 

818 tile_size, 

819 ) 

820 

821 

822@triton.jit 

823def global_cumsum_flat_impl( 

824 global_pid, 

825 total, 

826 ne_result_ptr: tl.tensor, 

827 tile_sum_ptr: tl.tensor, # in 

828 sorted_data_ptr: tl.tensor, 

829 sorted_indices_ptr: tl.tensor, # in 

830 data_out_ptr: tl.tensor, 

831 inverse_indices_ptr: tl.tensor, 

832 idx_ptr: tl.tensor, # out 

833 cumsum_out, 

834 ctas_num: tl.constexpr, 

835 global_ctas_num: int, 

836 next_power_global_ctas_num: tl.constexpr, 

837 num_tasks: int, 

838 tile_size: tl.constexpr, 

839 return_counts: tl.constexpr, 

840): 

841 offset = global_pid * tile_size 

842 r = tl.arange(0, tile_size) 

843 i0 = offset + r 

844 mask = i0 < num_tasks 

845 

846 # load sorted_data, sorted_indices 

847 sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

848 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

849 

850 # load tile_sum 

851 p = tl.arange(0, next_power_global_ctas_num) 

852 pre_tile_sum_mask = ( 

853 (p >= global_pid - ctas_num) 

854 & (p < global_pid) 

855 & (p >= 0) 

856 & (p < global_ctas_num) 

857 ) 

858 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

859 

860 # cumsum 

861 total += tl.sum(pre_tile_sum) 

862 # tl.device_print("total", total) 

863 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

864 ne_result_i1 = ne_result.to(tl.int1) 

865 ne_result = ne_result.to(tl.int32) 

866 # tl.device_print("ne_result", ne_result) 

867 cumsum = tl.cumsum(ne_result) 

868 # tl.store(cumsum_out + i0, cumsum) 

869 # tl.device_print("cumsum", cumsum) 

870 

871 # tile_sum 

872 if global_pid == global_ctas_num - 1: 

873 last_tile_sum_mask = i0 == num_tasks - 1 

874 tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) 

875 tile_offset = tl.where(last_tile_sum_mask, global_pid + tl.zeros_like(r), -1) 

876 tl.store( 

877 tile_sum_ptr + tile_offset, 

878 tile_sum, 

879 mask=last_tile_sum_mask, 

880 ) 

881 cumsum += total 

882 

883 # data_out: scatter_(to=cumsum, sorted_data) 

884 tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) 

885 

886 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

887 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

888 

889 # idx 

890 if return_counts: 

891 idx_mask = ((i0 == 0) | ne_result_i1) & mask 

892 idx_offset = tl.where(idx_mask, cumsum, num_tasks + 1) 

893 tl.store(idx_ptr + idx_offset, i0, mask=idx_mask) 

894 

895 return total 

896 

897 

898@libentry() 

899@triton.jit 

900def global_cumsum_flat_kernel( 

901 ne_result_ptr: tl.tensor, 

902 tile_sum_ptr: tl.tensor, # in 

903 sorted_data_ptr: tl.tensor, 

904 sorted_indices_ptr: tl.tensor, # in 

905 data_out_ptr: tl.tensor, 

906 inverse_indices_ptr: tl.tensor, 

907 idx_ptr: tl.tensor, # out 

908 cumsum_out, 

909 ctas_num: int, 

910 global_ctas_num: int, 

911 next_power_global_ctas_num: tl.constexpr, 

912 num_tasks: int, 

913 tiles_per_cta: int, 

914 tile_size: tl.constexpr, 

915 one_tile_per_cta: tl.constexpr, 

916 return_counts: tl.constexpr, 

917): 

918 pid = tle.program_id(0) 

919 ctas_num = tle.num_programs(0) 

920 if one_tile_per_cta: # monolitic kernel style 

921 global_cumsum_flat_impl( 

922 pid, 

923 0, 

924 ne_result_ptr, 

925 tile_sum_ptr, # in 

926 sorted_data_ptr, 

927 sorted_indices_ptr, # in 

928 data_out_ptr, 

929 inverse_indices_ptr, 

930 idx_ptr, # out 

931 cumsum_out, 

932 ctas_num, 

933 global_ctas_num, 

934 next_power_global_ctas_num, 

935 num_tasks, 

936 tile_size, 

937 return_counts, 

938 ) 

939 else: # grid-stride-loop style kernel 

940 total = tl.zeros([1], dtype=tl.int64) 

941 for j in range(0, tiles_per_cta): 

942 global_pid = pid + j * ctas_num 

943 total = global_cumsum_flat_impl( 

944 global_pid, 

945 total, 

946 ne_result_ptr, 

947 tile_sum_ptr, # in 

948 sorted_data_ptr, 

949 sorted_indices_ptr, # in 

950 data_out_ptr, 

951 inverse_indices_ptr, 

952 idx_ptr, # out 

953 cumsum_out, 

954 ctas_num, 

955 global_ctas_num, 

956 next_power_global_ctas_num, 

957 num_tasks, 

958 tile_size, 

959 return_counts, 

960 ) 

961 

962 

963@triton.jit 

964def global_cumsum_flat_impl_stage_1( 

965 global_pid, 

966 total, 

967 ne_result_ptr: tl.tensor, 

968 tile_sum_ptr: tl.tensor, # in 

969 sorted_data_ptr: tl.tensor, 

970 sorted_indices_ptr: tl.tensor, # in 

971 data_out_ptr: tl.tensor, 

972 inverse_indices_ptr: tl.tensor, 

973 idx_ptr: tl.tensor, # out 

974 total_in_ptr, 

975 cumsum_in_ptr, 

976 ctas_num: tl.constexpr, 

977 global_ctas_num: int, 

978 next_power_global_ctas_num: tl.constexpr, 

979 num_tasks: int, 

980 tile_size: tl.constexpr, 

981 return_counts: tl.constexpr, 

982): 

983 offset = global_pid * tile_size 

984 r = tl.arange(0, tile_size) 

985 i0 = offset + r 

986 mask = i0 < num_tasks 

987 

988 # load sorted_data, sorted_indices 

989 # sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

990 # sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

991 

992 # load tile_sum 

993 # p = tl.arange(0, next_power_global_ctas_num) 

994 # pre_tile_sum_mask = ( 

995 # (p >= global_pid - ctas_num) 

996 # & (p < global_pid) 

997 # & (p >= 0) 

998 # & (p < global_ctas_num) 

999 # ) 

1000 # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

1001 

1002 # cumsum 

1003 # total += tl.sum(pre_tile_sum) 

1004 # ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

1005 # ne_result_i1 = ne_result.to(tl.int1) 

1006 # ne_result = ne_result.to(tl.int32) 

1007 # cumsum = tl.cumsum(ne_result) 

1008 total_in_mask = global_pid < global_ctas_num 

1009 total = tl.load(total_in_ptr + global_pid, mask=total_in_mask) 

1010 

1011 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

1012 # ne_result_i1 = ne_result.to(tl.int1) 

1013 ne_result = ne_result.to(tl.int32) 

1014 # tl.device_print("ne_result", ne_result) 

1015 # cumsum = tl.cumsum(ne_result) 

1016 cumsum = tl.load(cumsum_in_ptr + i0) 

1017 

1018 # tile_sum 

1019 if global_pid == global_ctas_num - 1: 

1020 last_tile_sum_mask = i0 == num_tasks - 1 

1021 tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) 

1022 tile_offset = tl.where(last_tile_sum_mask, global_pid + tl.zeros_like(r), -1) 

1023 tl.store( 

1024 tile_sum_ptr + tile_offset, 

1025 tile_sum, 

1026 mask=last_tile_sum_mask, 

1027 ) 

1028 

1029 return total 

1030 

1031 

1032@libentry() 

1033@triton.jit 

1034def global_cumsum_flat_kernel_stage_1( 

1035 ne_result_ptr: tl.tensor, 

1036 tile_sum_ptr: tl.tensor, # in 

1037 sorted_data_ptr: tl.tensor, 

1038 sorted_indices_ptr: tl.tensor, # in 

1039 data_out_ptr: tl.tensor, 

1040 inverse_indices_ptr: tl.tensor, 

1041 idx_ptr: tl.tensor, # out 

1042 total_in_ptr, 

1043 cumsum_in_ptr, 

1044 ctas_num: int, 

1045 global_ctas_num: int, 

1046 next_power_global_ctas_num: tl.constexpr, 

1047 num_tasks: int, 

1048 tiles_per_cta: int, 

1049 tile_size: tl.constexpr, 

1050 one_tile_per_cta: tl.constexpr, 

1051 return_counts: tl.constexpr, 

1052): 

1053 pid = tle.program_id(0) 

1054 ctas_num = tle.num_programs(0) 

1055 if one_tile_per_cta: # monolitic kernel style 

1056 global_cumsum_flat_impl_stage_1( 

1057 pid, 

1058 0, 

1059 ne_result_ptr, 

1060 tile_sum_ptr, # in 

1061 sorted_data_ptr, 

1062 sorted_indices_ptr, # in 

1063 data_out_ptr, 

1064 inverse_indices_ptr, 

1065 idx_ptr, # out 

1066 total_in_ptr, 

1067 cumsum_in_ptr, 

1068 ctas_num, 

1069 global_ctas_num, 

1070 next_power_global_ctas_num, 

1071 num_tasks, 

1072 tile_size, 

1073 return_counts, 

1074 ) 

1075 else: # grid-stride-loop style kernel 

1076 total = tl.zeros([1], dtype=tl.int64) 

1077 for j in range(0, tiles_per_cta): 

1078 global_pid = pid + j * ctas_num 

1079 total = global_cumsum_flat_impl_stage_1( 

1080 global_pid, 

1081 total, 

1082 ne_result_ptr, 

1083 tile_sum_ptr, # in 

1084 sorted_data_ptr, 

1085 sorted_indices_ptr, # in 

1086 data_out_ptr, 

1087 inverse_indices_ptr, 

1088 idx_ptr, # out 

1089 total_in_ptr, 

1090 cumsum_in_ptr, 

1091 ctas_num, 

1092 global_ctas_num, 

1093 next_power_global_ctas_num, 

1094 num_tasks, 

1095 tile_size, 

1096 return_counts, 

1097 ) 

1098 

1099 

1100@triton.jit 

1101def global_cumsum_flat_impl_stage_2( 

1102 global_pid, 

1103 total, 

1104 ne_result_ptr: tl.tensor, 

1105 tile_sum_ptr: tl.tensor, # in 

1106 sorted_data_ptr: tl.tensor, 

1107 sorted_indices_ptr: tl.tensor, # in 

1108 data_out_ptr: tl.tensor, 

1109 inverse_indices_ptr: tl.tensor, 

1110 idx_ptr: tl.tensor, # out 

1111 total_in_ptr, 

1112 cumsum_in_ptr, 

1113 ctas_num: tl.constexpr, 

1114 global_ctas_num: int, 

1115 next_power_global_ctas_num: tl.constexpr, 

1116 num_tasks: int, 

1117 tile_size: tl.constexpr, 

1118 return_counts: tl.constexpr, 

1119): 

1120 offset = global_pid * tile_size 

1121 r = tl.arange(0, tile_size) 

1122 i0 = offset + r 

1123 mask = i0 < num_tasks 

1124 

1125 # load sorted_data, sorted_indices 

1126 sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

1127 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

1128 

1129 # load tile_sum 

1130 # p = tl.arange(0, next_power_global_ctas_num) 

1131 # pre_tile_sum_mask = ( 

1132 # (p >= global_pid - ctas_num) 

1133 # & (p < global_pid) 

1134 # & (p >= 0) 

1135 # & (p < global_ctas_num) 

1136 # ) 

1137 # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

1138 

1139 # cumsum 

1140 total_in_mask = global_pid < global_ctas_num 

1141 total = tl.load(total_in_ptr + global_pid, mask=total_in_mask) 

1142 

1143 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

1144 ne_result_i1 = ne_result.to(tl.int1) 

1145 ne_result = ne_result.to(tl.int32) 

1146 # tl.device_print("ne_result", ne_result) 

1147 # cumsum = tl.cumsum(ne_result) 

1148 cumsum = tl.load(cumsum_in_ptr + i0) 

1149 # tl.device_print("cumsum", cumsum) 

1150 cumsum += total 

1151 

1152 # data_out: scatter_(to=cumsum, sorted_data) 

1153 tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) 

1154 

1155 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

1156 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

1157 

1158 # idx 

1159 if return_counts: 

1160 idx_mask = ((i0 == 0) | ne_result_i1) & mask 

1161 idx_offset = tl.where(idx_mask, cumsum, num_tasks + 1) 

1162 tl.store(idx_ptr + idx_offset, i0, mask=idx_mask) 

1163 

1164 return total 

1165 

1166 

1167@libentry() 

1168@triton.jit 

1169def global_cumsum_flat_kernel_stage_2( 

1170 ne_result_ptr: tl.tensor, 

1171 tile_sum_ptr: tl.tensor, # in 

1172 sorted_data_ptr: tl.tensor, 

1173 sorted_indices_ptr: tl.tensor, # in 

1174 data_out_ptr: tl.tensor, 

1175 inverse_indices_ptr: tl.tensor, 

1176 idx_ptr: tl.tensor, # out 

1177 total_in_ptr, 

1178 cumsum_in_ptr, 

1179 ctas_num: int, 

1180 global_ctas_num: int, 

1181 next_power_global_ctas_num: tl.constexpr, 

1182 num_tasks: int, 

1183 tiles_per_cta: int, 

1184 tile_size: tl.constexpr, 

1185 one_tile_per_cta: tl.constexpr, 

1186 return_counts: tl.constexpr, 

1187): 

1188 pid = tle.program_id(0) 

1189 ctas_num = tle.num_programs(0) 

1190 if one_tile_per_cta: # monolitic kernel style 

1191 global_cumsum_flat_impl_stage_2( 

1192 pid, 

1193 0, 

1194 ne_result_ptr, 

1195 tile_sum_ptr, # in 

1196 sorted_data_ptr, 

1197 sorted_indices_ptr, # in 

1198 data_out_ptr, 

1199 inverse_indices_ptr, 

1200 idx_ptr, # out 

1201 total_in_ptr, 

1202 cumsum_in_ptr, 

1203 ctas_num, 

1204 global_ctas_num, 

1205 next_power_global_ctas_num, 

1206 num_tasks, 

1207 tile_size, 

1208 return_counts, 

1209 ) 

1210 else: # grid-stride-loop style kernel 

1211 total = tl.zeros([1], dtype=tl.int64) 

1212 for j in range(0, tiles_per_cta): 

1213 global_pid = pid + j * ctas_num 

1214 total = global_cumsum_flat_impl_stage_2( 

1215 global_pid, 

1216 total, 

1217 ne_result_ptr, 

1218 tile_sum_ptr, # in 

1219 sorted_data_ptr, 

1220 sorted_indices_ptr, # in 

1221 data_out_ptr, 

1222 inverse_indices_ptr, 

1223 idx_ptr, # out 

1224 total_in_ptr, 

1225 cumsum_in_ptr, 

1226 ctas_num, 

1227 global_ctas_num, 

1228 next_power_global_ctas_num, 

1229 num_tasks, 

1230 tile_size, 

1231 return_counts, 

1232 ) 

1233 

1234 

1235def sorted_indices_unique_flat( 

1236 sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool 

1237): 

1238 num_tasks = sorted_data.numel() 

1239 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

1240 tile_size = min(2048, next_power_num_tasks) 

1241 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

1242 # if global_ctas_num <= 8192: 

1243 # min_tile_size = 512 if global_ctas_num > 32 else 256 

1244 # tile_size = max( 

1245 # min_tile_size, 

1246 # min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks), 

1247 # ) 

1248 # global_ctas_num = triton.cdiv(num_tasks, tile_size) 

1249 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

1250 ctas_num = global_ctas_num # if global_ctas_num < 32768 else 8192 

1251 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

1252 num_warps = 8 if tiles_per_cta == 1 else 32 

1253 grid = (ctas_num, 1, 1) 

1254 # print(f"ctas_num = {ctas_num}") 

1255 # print(f"tile_size = {tile_size}") 

1256 # print(f"tiles_per_cta = {tiles_per_cta}") 

1257 # print(f"global_ctas_num = {global_ctas_num}") 

1258 

1259 # allocate tensor 

1260 ne_result = torch.empty_like(sorted_data, dtype=torch.bool) 

1261 tile_sum = torch.empty( 

1262 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

1263 ) 

1264 data_out = torch.empty_like(sorted_data) 

1265 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

1266 idx = None 

1267 if return_counts: 

1268 idx = torch.empty_like(inverse_indices) 

1269 

1270 # assert tiles_per_cta == 1 

1271 

1272 # launch kernel 

1273 with torch_device_fn.device(sorted_data.device.index): 

1274 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1275 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1276 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1277 

1278 local_ne_flat_kernel[grid]( 

1279 sorted_data, # in 

1280 ne_result, 

1281 tile_sum, # out 

1282 global_ctas_num, 

1283 num_tasks, 

1284 tiles_per_cta=tiles_per_cta, 

1285 tile_size=tile_size, 

1286 num_warps=num_warps, 

1287 ) 

1288 if "TRITONXPU_OTHER_SIM" in os.environ: 

1289 del os.environ["TRITONXPU_OTHER_SIM"] 

1290 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1291 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1292 if "TRITONXPU_INTERLEAVE" in os.environ: 

1293 del os.environ["TRITONXPU_INTERLEAVE"] 

1294 

1295 if num_tasks < 2**26: 

1296 # print(f"ne_result.shape = {ne_result.shape}") 

1297 # print(f"tile_sum.shape = {tile_sum.shape}") 

1298 # print(f'tile_sum.cpu() = {tile_sum.cpu()}') 

1299 next_multiple = ((num_tasks // 2048) + 1) * 2048 

1300 cumsum_out = torch.zeros(next_multiple) 

1301 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1302 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1303 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1304 global_cumsum_flat_kernel[grid]( 

1305 ne_result, 

1306 tile_sum, # in 

1307 sorted_data, 

1308 sorted_indices, # in 

1309 data_out, 

1310 inverse_indices, 

1311 idx, # out 

1312 cumsum_out, 

1313 ctas_num, 

1314 global_ctas_num, 

1315 next_power_global_ctas_num, 

1316 num_tasks, 

1317 tiles_per_cta=tiles_per_cta, 

1318 tile_size=tile_size, 

1319 one_tile_per_cta=tiles_per_cta == 1, 

1320 return_counts=return_counts, 

1321 num_warps=num_warps, 

1322 ) 

1323 if "TRITONXPU_OTHER_SIM" in os.environ: 

1324 del os.environ["TRITONXPU_OTHER_SIM"] 

1325 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1326 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1327 if "TRITONXPU_INTERLEAVE" in os.environ: 

1328 del os.environ["TRITONXPU_INTERLEAVE"] 

1329 # print(f'cumsum_out = {cumsum_out.cpu()}') 

1330 # print(f'out tile_sum.cpu() = {tile_sum.cpu()}') 

1331 

1332 else: 

1333 total_in = torch.cumsum(tile_sum, dim=0) 

1334 total_in = torch.roll(total_in, shifts=1) 

1335 total_in[0] = 0 

1336 # print(f"total_in.shape = {total_in.shape}") 

1337 # print(f"total_in.cpu() = {total_in.cpu()}") 

1338 

1339 # ne_result = torch.cumsum(ne_result, dim=0) 

1340 # print(f"ne_result.shape = {ne_result.shape}") 

1341 next_multiple = ((num_tasks // 2048) + 1) * 2048 

1342 padding_size = next_multiple - num_tasks # 96256 - 96000 = 256 

1343 padded_ne_result = torch.nn.functional.pad( 

1344 ne_result, (0, padding_size), "constant", 0 

1345 ) 

1346 num_blocks = next_multiple // 2048 # 96256 / 2048 = 47 

1347 reshaped = padded_ne_result.view(num_blocks, 2048) 

1348 cumsum_blocks = torch.cumsum(reshaped, dim=1) 

1349 cumsum_result = cumsum_blocks.view(-1) 

1350 

1351 # print(f'ne_result.cpu() = {ne_result.cpu()}') 

1352 

1353 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1354 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1355 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1356 global_cumsum_flat_kernel_stage_1[grid]( 

1357 ne_result, 

1358 tile_sum, # in 

1359 sorted_data, 

1360 sorted_indices, # in 

1361 data_out, 

1362 inverse_indices, 

1363 idx, # out 

1364 total_in, 

1365 cumsum_result, 

1366 ctas_num, 

1367 global_ctas_num, 

1368 next_power_global_ctas_num, 

1369 num_tasks, 

1370 tiles_per_cta=tiles_per_cta, 

1371 tile_size=tile_size, 

1372 one_tile_per_cta=tiles_per_cta == 1, 

1373 return_counts=return_counts, 

1374 num_warps=num_warps, 

1375 ) 

1376 if "TRITONXPU_OTHER_SIM" in os.environ: 

1377 del os.environ["TRITONXPU_OTHER_SIM"] 

1378 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1379 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1380 if "TRITONXPU_INTERLEAVE" in os.environ: 

1381 del os.environ["TRITONXPU_INTERLEAVE"] 

1382 

1383 # print(f'out tile_sum.cpu() = {tile_sum.cpu()}') 

1384 

1385 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1386 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1387 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1388 global_cumsum_flat_kernel_stage_2[grid]( 

1389 ne_result, 

1390 tile_sum, # in 

1391 sorted_data, 

1392 sorted_indices, # in 

1393 data_out, 

1394 inverse_indices, 

1395 idx, # out 

1396 total_in, 

1397 cumsum_result, 

1398 ctas_num, 

1399 global_ctas_num, 

1400 next_power_global_ctas_num, 

1401 num_tasks, 

1402 tiles_per_cta=tiles_per_cta, 

1403 tile_size=tile_size, 

1404 one_tile_per_cta=tiles_per_cta == 1, 

1405 return_counts=return_counts, 

1406 num_warps=num_warps, 

1407 isCloseUnrollControl=True, 

1408 ) 

1409 if "TRITONXPU_OTHER_SIM" in os.environ: 

1410 del os.environ["TRITONXPU_OTHER_SIM"] 

1411 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1412 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1413 if "TRITONXPU_INTERLEAVE" in os.environ: 

1414 del os.environ["TRITONXPU_INTERLEAVE"] 

1415 

1416 out_size = tile_sum[-1].item() + 1 

1417 counts = None 

1418 if return_counts: 

1419 idx = idx[:out_size] 

1420 counts = torch.empty_like(idx) 

1421 # print("i am here!!!!") 

1422 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1423 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1424 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1425 output_counts_flat_kernel[grid]( 

1426 idx, 

1427 num_tasks, # in 

1428 counts, # out 

1429 out_size, 

1430 tiles_per_cta, 

1431 tile_size, 

1432 num_warps=num_warps, 

1433 ) 

1434 if "TRITONXPU_OTHER_SIM" in os.environ: 

1435 del os.environ["TRITONXPU_OTHER_SIM"] 

1436 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1437 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1438 if "TRITONXPU_INTERLEAVE" in os.environ: 

1439 del os.environ["TRITONXPU_INTERLEAVE"] 

1440 

1441 return data_out[:out_size], inverse_indices, counts 

1442 

1443 

1444def simple_unique_flat( 

1445 sorted_data: torch.Tensor, 

1446 sorted_indices: torch.Tensor, 

1447 return_inverse: bool, 

1448 return_counts: bool, 

1449): 

1450 num_tasks = sorted_data.numel() 

1451 grid = (1, 1, 1) 

1452 

1453 # allocate tensor 

1454 data_out = torch.zeros_like(sorted_data) 

1455 if return_inverse: 

1456 inverse_indices = torch.zeros_like(sorted_data, dtype=torch.int64) 

1457 else: 

1458 inverse_indices = None 

1459 if return_counts: 

1460 idx = torch.zeros_like(sorted_data, dtype=torch.int64) 

1461 else: 

1462 idx = None 

1463 unique_size = torch.zeros([1], dtype=torch.int64, device=sorted_data.device) 

1464 

1465 # launch kernel 

1466 with torch_device_fn.device(sorted_data.device.index): 

1467 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1468 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1469 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1470 simple_unique_flat_kernel[grid]( 

1471 sorted_data, 

1472 sorted_indices, # in 

1473 data_out, 

1474 inverse_indices, 

1475 idx, 

1476 unique_size, # out 

1477 return_inverse, 

1478 return_counts, 

1479 num_tasks, 

1480 tile_size=triton.next_power_of_2(num_tasks), 

1481 num_warps=8, 

1482 ) 

1483 if "TRITONXPU_OTHER_SIM" in os.environ: 

1484 del os.environ["TRITONXPU_OTHER_SIM"] 

1485 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1486 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1487 if "TRITONXPU_INTERLEAVE" in os.environ: 

1488 del os.environ["TRITONXPU_INTERLEAVE"] 

1489 out_size = unique_size.item() + 1 

1490 # print(f"unique_size.item() = {unique_size.item()}") 

1491 counts = None 

1492 if return_counts: 

1493 idx = idx[:out_size] 

1494 counts = torch.empty_like(idx) 

1495 with torch_device_fn.device(sorted_data.device.index): 

1496 os.environ["TRITONXPU_OTHER_SIM"] = "1" 

1497 os.environ["TRITONXPU_STORE_MASK_SIM"] = "1" 

1498 os.environ["TRITONXPU_INTERLEAVE"] = "0" 

1499 output_counts_flat_kernel[grid]( 

1500 idx, 

1501 num_tasks, # in 

1502 counts, # out 

1503 num_tasks=out_size, 

1504 tiles_per_cta=1, 

1505 tile_size=triton.next_power_of_2(out_size), 

1506 num_warps=8, 

1507 ) 

1508 if "TRITONXPU_OTHER_SIM" in os.environ: 

1509 del os.environ["TRITONXPU_OTHER_SIM"] 

1510 if "TRITONXPU_STORE_MASK_SIM" in os.environ: 

1511 del os.environ["TRITONXPU_STORE_MASK_SIM"] 

1512 if "TRITONXPU_INTERLEAVE" in os.environ: 

1513 del os.environ["TRITONXPU_INTERLEAVE"] 

1514 return data_out[:out_size], inverse_indices, counts 

1515 

1516 

1517def _unique2( 

1518 in0: torch.Tensor, 

1519 sorted: bool = True, 

1520 return_inverse: bool = False, 

1521 return_counts: bool = False, 

1522): 

1523 if in0.numel() <= 8192: 

1524 # print("simple_unique_flat") 

1525 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

1526 data_out, inverse_indices, counts = simple_unique_flat( 

1527 sorted_data, sorted_indices, return_inverse, return_counts 

1528 ) 

1529 elif return_inverse: 

1530 # print("sorted_indices_unique_flat") 

1531 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

1532 data_out, inverse_indices, counts = sorted_indices_unique_flat( 

1533 sorted_data, sorted_indices, return_counts 

1534 ) 

1535 else: 

1536 # print("sorted_quick_unique_flat") 

1537 sorted_data, _ = torch.sort(in0.ravel()) 

1538 data_out, inverse_indices, counts = sorted_quick_unique_flat( 

1539 sorted_data, return_counts 

1540 ) 

1541 return ( 

1542 data_out, 

1543 inverse_indices if inverse_indices is None else inverse_indices.view_as(in0), 

1544 counts, 

1545 )