Coverage for src/flag_gems/runtime/backend/_mthreads/ops/unique.py: 0%

284 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.runtime import torch_device_fn 

6from flag_gems.utils import triton_lang_extension as tle 

7from flag_gems.utils.libentry import libentry 

8 

9 

10@libentry() 

11@triton.jit 

12def simple_unique_flat_kernel( 

13 sorted_data_ptr: tl.tensor, 

14 sorted_indices_ptr: tl.tensor, # in 

15 data_out_ptr: tl.tensor, 

16 inverse_indices_ptr: tl.tensor, 

17 idx_ptr: tl.tensor, 

18 unique_size_ptr: tl.tensor, # out 

19 return_inverse: tl.constexpr, 

20 return_counts: tl.constexpr, 

21 num_tasks: int, 

22 tile_size: tl.constexpr, 

23): 

24 i0 = tl.arange(0, tile_size) 

25 mask = i0 < num_tasks 

26 

27 # load 

28 a = tl.load(sorted_data_ptr + i0, mask=mask) 

29 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

30 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

31 

32 # ne & cumsum 

33 ne_result = tl.where(i0 > 0, a != b, 0) 

34 cumsum = tl.cumsum(ne_result) 

35 

36 # unique_size 

37 unique_size_mask = i0 == tile_size - 1 

38 tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask) 

39 

40 # data_out: scatter_(to=cumsum, sorted_data) 

41 tl.store(data_out_ptr + cumsum, a, mask=mask) 

42 

43 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

44 if return_inverse: 

45 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

46 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

47 

48 # idx 

49 if return_counts: 

50 idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask 

51 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

52 

53 

54@triton.jit 

55def output_counts_flat_impl( 

56 global_pid, 

57 idx_ptr: tl.tensor, 

58 origin_num_tasks: int, # in 

59 counts_ptr: tl.tensor, # out 

60 num_tasks: int, 

61 tile_size: tl.constexpr, 

62): 

63 r = tl.arange(0, tile_size) 

64 

65 # load idx 

66 i0 = global_pid * tile_size + r 

67 mask = i0 < num_tasks 

68 idx = tl.load(idx_ptr + i0, mask=mask) 

69 

70 # load idx_next 

71 i0_next = i0 + 1 

72 next_mask = i0_next < num_tasks 

73 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

74 

75 # diff 

76 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

77 

78 # store counts 

79 tl.store(counts_ptr + i0, counts, mask=mask) 

80 

81 

82@libentry() 

83@triton.jit 

84def output_counts_flat_kernel( 

85 idx_ptr: tl.tensor, 

86 origin_num_tasks: int, # in 

87 counts_ptr: tl.tensor, # out 

88 num_tasks: int, 

89 tiles_per_cta: int, 

90 tile_size: tl.constexpr, 

91): 

92 pid = tle.program_id(0) 

93 ctas_num = tle.num_programs(0) 

94 # grid-stride-loop style kernel 

95 for j in range(0, tiles_per_cta): 

96 global_pid = pid + j * ctas_num 

97 output_counts_flat_impl( 

98 global_pid, 

99 idx_ptr, 

100 origin_num_tasks, # in 

101 counts_ptr, # out 

102 num_tasks, 

103 tile_size, 

104 ) 

105 

106 

107@triton.jit 

108def quick_output_flat_impl( 

109 global_pid, 

110 sorted_data_ptr: tl.tensor, 

111 idx_ptr: tl.tensor, 

112 origin_num_tasks: int, # in 

113 data_out_ptr: tl.tensor, 

114 counts_ptr: tl.tensor, # out 

115 num_tasks: int, 

116 tile_size: tl.constexpr, 

117): 

118 r = tl.arange(0, tile_size) 

119 

120 # load idx 

121 i0 = global_pid * tile_size + r 

122 mask = i0 < num_tasks 

123 idx = tl.load(idx_ptr + i0, mask=mask) 

124 

125 # load idx_next 

126 i0_next = i0 + 1 

127 next_mask = i0_next < num_tasks 

128 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

129 

130 # diff 

131 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

132 

133 # store counts 

134 tl.store(counts_ptr + i0, counts, mask=mask) 

135 

136 # data_out: gather(sorted_data, from=idx) 

137 sorted_data = tl.load(sorted_data_ptr + idx, mask=mask) 

138 tl.store(data_out_ptr + i0, sorted_data, mask=mask) 

139 

140 

141@libentry() 

142@triton.jit 

143def quick_output_flat_kernel( 

144 sorted_data_ptr: tl.tensor, 

145 idx_ptr: tl.tensor, 

146 origin_num_tasks: int, # in 

147 data_out_ptr: tl.tensor, 

148 counts_ptr: tl.tensor, # out 

149 num_tasks: int, 

150 tiles_per_cta: int, 

151 tile_size: tl.constexpr, 

152): 

153 pid = tle.program_id(0) 

154 ctas_num = tle.num_programs(0) 

155 # grid-stride-loop style kernel 

156 for j in range(0, tiles_per_cta): 

157 global_pid = pid + j * ctas_num 

158 quick_output_flat_impl( 

159 global_pid, 

160 sorted_data_ptr, 

161 idx_ptr, 

162 origin_num_tasks, # in 

163 data_out_ptr, 

164 counts_ptr, # out 

165 num_tasks, 

166 tile_size, 

167 ) 

168 

169 

170@triton.jit 

171def local_quick_unique_flat_impl( 

172 global_pid, 

173 sorted_data_ptr: tl.tensor, # in 

174 local_unique_ptr: tl.tensor, 

175 origin_idx_ptr: tl.tensor, 

176 tile_sum_ptr: tl.tensor, # out 

177 global_ctas_num: int, 

178 num_tasks: int, 

179 tile_size: tl.constexpr, 

180 return_counts: tl.constexpr, 

181): 

182 offset = global_pid * tile_size 

183 r = tl.arange(0, tile_size) 

184 i0 = offset + r 

185 mask = i0 < num_tasks 

186 

187 # load 

188 a = tl.load(sorted_data_ptr + i0, mask=mask) 

189 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

190 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

191 

192 # ne & cumsum 

193 ne_result = tl.where(i0 > 0, a != b, 0) 

194 cumsum = tl.cumsum(ne_result) 

195 

196 # local_id or local_unique 

197 local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0) 

198 local_unique_mask = (local_unique_offset >= 0) & mask 

199 if return_counts: 

200 # origin_idx: scatter_(to=cumsum, i0) 

201 origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask 

202 tl.store( 

203 origin_idx_ptr + (offset + local_unique_offset), 

204 i0, 

205 mask=origin_idx_mask, 

206 ) 

207 else: 

208 # local_unique: scatter_(to=cumsum, sorted_data) 

209 tl.store( 

210 local_unique_ptr + (offset + local_unique_offset), a, mask=local_unique_mask 

211 ) 

212 

213 # tile_sum 

214 tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num) 

215 tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum) 

216 tl.store(tile_sum_ptr + global_pid + tl.zeros_like(r), tile_sum, mask=tile_sum_mask) 

217 

218 

219@libentry() 

220@triton.jit 

221def local_quick_unique_flat_kernel( 

222 sorted_data_ptr: tl.tensor, # in 

223 local_unique_ptr: tl.tensor, 

224 origin_idx_ptr: tl.tensor, 

225 tile_sum_ptr: tl.tensor, # out 

226 global_ctas_num: int, 

227 num_tasks: int, 

228 tiles_per_cta: int, 

229 tile_size: tl.constexpr, 

230 return_counts: tl.constexpr, 

231): 

232 pid = tle.program_id(0) 

233 ctas_num = tle.num_programs(0) 

234 # grid-stride-loop style kernel 

235 for j in range(0, tiles_per_cta): 

236 global_pid = pid + j * ctas_num 

237 local_quick_unique_flat_impl( 

238 global_pid, 

239 sorted_data_ptr, # in 

240 local_unique_ptr, 

241 origin_idx_ptr, 

242 tile_sum_ptr, # out 

243 global_ctas_num, 

244 num_tasks, 

245 tile_size, 

246 return_counts, 

247 ) 

248 

249 

250@triton.jit 

251def global_quick_unique_flat_impl( 

252 global_pid, 

253 total, 

254 local_unique_ptr: tl.tensor, 

255 origin_idx_ptr: tl.tensor, 

256 tile_sum_ptr: tl.tensor, # in 

257 data_out_ptr: tl.tensor, 

258 idx_ptr: tl.tensor, # out 

259 ctas_num: int, 

260 global_ctas_num: int, 

261 next_power_global_ctas_num: tl.constexpr, 

262 num_tasks: int, 

263 tile_size: tl.constexpr, 

264 return_counts: tl.constexpr, 

265): 

266 r = tl.arange(0, tile_size) 

267 i0 = global_pid * tile_size + r 

268 mask = i0 < num_tasks 

269 

270 # load tile_sum 

271 p = tl.arange(0, next_power_global_ctas_num) 

272 pre_tile_sum_mask = ( 

273 (p >= global_pid - ctas_num) 

274 & (p < global_pid) 

275 & (p >= 0) 

276 & (p < global_ctas_num) 

277 ) 

278 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

279 cur_tile_sum_mask = global_pid < global_ctas_num 

280 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) 

281 

282 # total 

283 total += tl.sum(pre_tile_sum) 

284 if global_pid == global_ctas_num - 1: 

285 last_tile_sum_mask = p == global_pid 

286 tl.store(tile_sum_ptr + p, total + cur_tile_sum, mask=last_tile_sum_mask) 

287 

288 # idx or data_out 

289 tile_mask = r < cur_tile_sum 

290 out_offset = total + r 

291 if return_counts: 

292 # move origin_idx to idx_ptr 

293 origin_idx = tl.load(origin_idx_ptr + i0, mask=mask) 

294 tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask) 

295 else: 

296 # move local_unique to data_out_ptr 

297 local_unique = tl.load(local_unique_ptr + i0, mask=mask) 

298 tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask) 

299 

300 return total 

301 

302 

303@libentry() 

304@triton.jit 

305def global_quick_unique_flat_kernel( 

306 local_unique_ptr: tl.tensor, 

307 origin_idx_ptr: tl.tensor, 

308 tile_sum_ptr: tl.tensor, # in 

309 data_out_ptr: tl.tensor, 

310 idx_ptr: tl.tensor, # out 

311 ctas_num: int, 

312 global_ctas_num: int, 

313 next_power_global_ctas_num: tl.constexpr, 

314 num_tasks: int, 

315 tiles_per_cta: int, 

316 tile_size: tl.constexpr, 

317 one_tile_per_cta: tl.constexpr, 

318 return_counts: tl.constexpr, 

319): 

320 pid = tle.program_id(0) 

321 ctas_num = tle.num_programs(0) 

322 if one_tile_per_cta: # monolitic kernel style 

323 global_quick_unique_flat_impl( 

324 pid, 

325 0, 

326 local_unique_ptr, 

327 origin_idx_ptr, 

328 tile_sum_ptr, # in 

329 data_out_ptr, 

330 idx_ptr, # out 

331 ctas_num, 

332 global_ctas_num, 

333 next_power_global_ctas_num, 

334 num_tasks, 

335 tile_size, 

336 return_counts, 

337 ) 

338 else: # grid-stride-loop style kernel 

339 total = tl.zeros([1], dtype=tl.int64) 

340 for j in range(0, tiles_per_cta): 

341 global_pid = pid + j * ctas_num 

342 total = global_quick_unique_flat_impl( 

343 global_pid, 

344 total, 

345 local_unique_ptr, 

346 origin_idx_ptr, 

347 tile_sum_ptr, # in 

348 data_out_ptr, 

349 idx_ptr, # out 

350 ctas_num, 

351 global_ctas_num, 

352 next_power_global_ctas_num, 

353 num_tasks, 

354 tile_size, 

355 return_counts, 

356 ) 

357 

358 

359def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): 

360 num_tasks = sorted_data.numel() 

361 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

362 tile_size = min(8192, next_power_num_tasks) 

363 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

364 if global_ctas_num <= 8192: 

365 tile_size = max( 

366 32, min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks) 

367 ) 

368 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

369 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

370 ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048 

371 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

372 num_warps = 8 if tiles_per_cta == 1 else 8 

373 grid = (ctas_num, 1, 1) 

374 

375 # allocate tensor 

376 if return_counts: 

377 local_unique = None 

378 origin_idx = torch.empty_like(sorted_data, dtype=torch.int64) 

379 idx = torch.empty_like(origin_idx) 

380 else: 

381 local_unique = torch.empty_like(sorted_data) 

382 origin_idx = None 

383 idx = None 

384 counts = None 

385 tile_sum = torch.empty( 

386 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

387 ) 

388 data_out = None 

389 if not return_counts: 

390 data_out = torch.empty_like(sorted_data) 

391 

392 # launch kernel 

393 with torch_device_fn.device(sorted_data.device.index): 

394 local_quick_unique_flat_kernel[grid]( 

395 sorted_data, # in 

396 local_unique, 

397 origin_idx, 

398 tile_sum, # out 

399 global_ctas_num, 

400 num_tasks, 

401 tiles_per_cta=tiles_per_cta, 

402 tile_size=tile_size, 

403 return_counts=return_counts, 

404 num_warps=num_warps, 

405 ) 

406 global_quick_unique_flat_kernel[grid]( 

407 local_unique, 

408 origin_idx, 

409 tile_sum, # in 

410 data_out, 

411 idx, # out 

412 ctas_num, 

413 global_ctas_num, 

414 next_power_global_ctas_num, 

415 num_tasks, 

416 tiles_per_cta=tiles_per_cta, 

417 tile_size=tile_size, 

418 one_tile_per_cta=tiles_per_cta == 1, 

419 return_counts=return_counts, 

420 num_warps=num_warps, 

421 ) 

422 out_size = tile_sum[-1].item() 

423 if return_counts: 

424 data_out = torch.empty( 

425 (out_size,), dtype=sorted_data.dtype, device=sorted_data.device 

426 ) 

427 idx = idx[:out_size] 

428 counts = origin_idx[:out_size] 

429 quick_output_flat_kernel[grid]( 

430 sorted_data, 

431 idx, 

432 num_tasks, # in 

433 data_out, 

434 counts, # out 

435 out_size, 

436 tiles_per_cta, 

437 tile_size, 

438 num_warps=num_warps, 

439 ) 

440 

441 if return_counts: 

442 return data_out, None, counts 

443 else: 

444 return data_out[:out_size], None, None 

445 

446 

447@triton.jit 

448def local_ne_flat_impl( 

449 global_pid, 

450 sorted_data_ptr: tl.tensor, # in 

451 ne_result_ptr: tl.tensor, 

452 tile_sum_ptr: tl.tensor, # out 

453 global_ctas_num: int, 

454 num_tasks: int, 

455 tile_size: tl.constexpr, 

456): 

457 r = tl.arange(0, tile_size) 

458 i0 = global_pid * tile_size + r 

459 mask = i0 < num_tasks 

460 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

461 

462 # load 

463 a = tl.load(sorted_data_ptr + i0, mask=mask) 

464 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

465 

466 # compute 

467 ne_result = tl.where(i0 > 0, a != b, 0) 

468 

469 # store ne_result 

470 tl.store(ne_result_ptr + i0, ne_result, mask=mask) 

471 

472 # store tile_sum 

473 tile_sum = tl.sum(ne_result) 

474 tile_sum_mask = global_pid < global_ctas_num 

475 tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask) 

476 

477 

478@libentry() 

479@triton.jit 

480def local_ne_flat_kernel( 

481 sorted_data_ptr: tl.tensor, # in 

482 ne_result_ptr: tl.tensor, 

483 tile_sum_ptr: tl.tensor, # out 

484 global_ctas_num: int, 

485 num_tasks: int, 

486 tiles_per_cta: int, 

487 tile_size: tl.constexpr, 

488): 

489 pid = tle.program_id(0) 

490 ctas_num = tle.num_programs(0) 

491 # grid-stride-loop style kernel 

492 for j in range(0, tiles_per_cta): 

493 global_pid = pid + j * ctas_num 

494 local_ne_flat_impl( 

495 global_pid, 

496 sorted_data_ptr, # in 

497 ne_result_ptr, 

498 tile_sum_ptr, # out 

499 global_ctas_num, 

500 num_tasks, 

501 tile_size, 

502 ) 

503 

504 

505@triton.jit 

506def global_cumsum_flat_impl( 

507 global_pid, 

508 total, 

509 ne_result_ptr: tl.tensor, 

510 tile_sum_ptr: tl.tensor, # in 

511 sorted_data_ptr: tl.tensor, 

512 sorted_indices_ptr: tl.tensor, # in 

513 data_out_ptr: tl.tensor, 

514 inverse_indices_ptr: tl.tensor, 

515 idx_ptr: tl.tensor, # out 

516 ctas_num: tl.constexpr, 

517 global_ctas_num: int, 

518 next_power_global_ctas_num: tl.constexpr, 

519 num_tasks: int, 

520 tile_size: tl.constexpr, 

521 return_counts: tl.constexpr, 

522): 

523 offset = global_pid * tile_size 

524 r = tl.arange(0, tile_size) 

525 i0 = offset + r 

526 mask = i0 < num_tasks 

527 

528 # load sorted_data, sorted_indices 

529 sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

530 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

531 

532 # load tile_sum 

533 p = tl.arange(0, next_power_global_ctas_num) 

534 pre_tile_sum_mask = ( 

535 (p >= global_pid - ctas_num) 

536 & (p < global_pid) 

537 & (p >= 0) 

538 & (p < global_ctas_num) 

539 ) 

540 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

541 

542 # cumsum 

543 total += tl.sum(pre_tile_sum) 

544 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

545 ne_result_i1 = ne_result.to(tl.int1) 

546 ne_result = ne_result.to(tl.int32) 

547 cumsum = tl.cumsum(ne_result) 

548 

549 # tile_sum 

550 if global_pid == global_ctas_num - 1: 

551 last_tile_sum_mask = i0 == num_tasks - 1 

552 tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) 

553 tl.store( 

554 tile_sum_ptr + global_pid + tl.zeros_like(r), 

555 tile_sum, 

556 mask=last_tile_sum_mask, 

557 ) 

558 cumsum += total 

559 

560 # data_out: scatter_(to=cumsum, sorted_data) 

561 tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) 

562 

563 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

564 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

565 

566 # idx 

567 if return_counts: 

568 idx_mask = ((i0 == 0) | ne_result_i1) & mask 

569 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

570 

571 return total 

572 

573 

574@libentry() 

575@triton.jit 

576def global_cumsum_flat_kernel( 

577 ne_result_ptr: tl.tensor, 

578 tile_sum_ptr: tl.tensor, # in 

579 sorted_data_ptr: tl.tensor, 

580 sorted_indices_ptr: tl.tensor, # in 

581 data_out_ptr: tl.tensor, 

582 inverse_indices_ptr: tl.tensor, 

583 idx_ptr: tl.tensor, # out 

584 ctas_num: int, 

585 global_ctas_num: int, 

586 next_power_global_ctas_num: tl.constexpr, 

587 num_tasks: int, 

588 tiles_per_cta: int, 

589 tile_size: tl.constexpr, 

590 one_tile_per_cta: tl.constexpr, 

591 return_counts: tl.constexpr, 

592): 

593 pid = tle.program_id(0) 

594 ctas_num = tle.num_programs(0) 

595 if one_tile_per_cta: # monolitic kernel style 

596 global_cumsum_flat_impl( 

597 pid, 

598 0, 

599 ne_result_ptr, 

600 tile_sum_ptr, # in 

601 sorted_data_ptr, 

602 sorted_indices_ptr, # in 

603 data_out_ptr, 

604 inverse_indices_ptr, 

605 idx_ptr, # out 

606 ctas_num, 

607 global_ctas_num, 

608 next_power_global_ctas_num, 

609 num_tasks, 

610 tile_size, 

611 return_counts, 

612 ) 

613 else: # grid-stride-loop style kernel 

614 total = tl.zeros([1], dtype=tl.int64) 

615 for j in range(0, tiles_per_cta): 

616 global_pid = pid + j * ctas_num 

617 total = global_cumsum_flat_impl( 

618 global_pid, 

619 total, 

620 ne_result_ptr, 

621 tile_sum_ptr, # in 

622 sorted_data_ptr, 

623 sorted_indices_ptr, # in 

624 data_out_ptr, 

625 inverse_indices_ptr, 

626 idx_ptr, # out 

627 ctas_num, 

628 global_ctas_num, 

629 next_power_global_ctas_num, 

630 num_tasks, 

631 tile_size, 

632 return_counts, 

633 ) 

634 

635 

636def sorted_indices_unique_flat( 

637 sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool 

638): 

639 num_tasks = sorted_data.numel() 

640 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

641 tile_size = min(8192, next_power_num_tasks) 

642 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

643 if global_ctas_num <= 8192: 

644 min_tile_size = 512 if global_ctas_num > 32 else 256 

645 tile_size = max( 

646 min_tile_size, 

647 min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks), 

648 ) 

649 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

650 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

651 ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192 

652 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

653 num_warps = 8 if tiles_per_cta == 1 else 8 

654 grid = (ctas_num, 1, 1) 

655 

656 # allocate tensor 

657 ne_result = torch.empty_like(sorted_data, dtype=torch.bool) 

658 tile_sum = torch.empty( 

659 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

660 ) 

661 data_out = torch.empty_like(sorted_data) 

662 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

663 idx = None 

664 if return_counts: 

665 idx = torch.empty_like(inverse_indices) 

666 

667 # launch kernel 

668 with torch_device_fn.device(sorted_data.device.index): 

669 local_ne_flat_kernel[grid]( 

670 sorted_data, # in 

671 ne_result, 

672 tile_sum, # out 

673 global_ctas_num, 

674 num_tasks, 

675 tiles_per_cta=tiles_per_cta, 

676 tile_size=tile_size, 

677 num_warps=num_warps, 

678 ) 

679 global_cumsum_flat_kernel[grid]( 

680 ne_result, 

681 tile_sum, # in 

682 sorted_data, 

683 sorted_indices, # in 

684 data_out, 

685 inverse_indices, 

686 idx, # out 

687 ctas_num, 

688 global_ctas_num, 

689 next_power_global_ctas_num, 

690 num_tasks, 

691 tiles_per_cta=tiles_per_cta, 

692 tile_size=tile_size, 

693 one_tile_per_cta=tiles_per_cta == 1, 

694 return_counts=return_counts, 

695 num_warps=num_warps, 

696 ) 

697 out_size = tile_sum[-1].item() + 1 

698 counts = None 

699 if return_counts: 

700 idx = idx[:out_size] 

701 counts = torch.empty_like(idx) 

702 output_counts_flat_kernel[grid]( 

703 idx, 

704 num_tasks, # in 

705 counts, # out 

706 out_size, 

707 tiles_per_cta, 

708 tile_size, 

709 num_warps=num_warps, 

710 ) 

711 

712 return data_out[:out_size], inverse_indices, counts 

713 

714 

715def simple_unique_flat( 

716 sorted_data: torch.Tensor, 

717 sorted_indices: torch.Tensor, 

718 return_inverse: bool, 

719 return_counts: bool, 

720): 

721 num_tasks = sorted_data.numel() 

722 grid = (1, 1, 1) 

723 

724 # allocate tensor 

725 data_out = torch.empty_like(sorted_data) 

726 if return_inverse: 

727 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

728 else: 

729 inverse_indices = None 

730 if return_counts: 

731 idx = torch.empty_like(sorted_data, dtype=torch.int64) 

732 else: 

733 idx = None 

734 unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device) 

735 

736 # launch kernel 

737 with torch_device_fn.device(sorted_data.device.index): 

738 simple_unique_flat_kernel[grid]( 

739 sorted_data, 

740 sorted_indices, # in 

741 data_out, 

742 inverse_indices, 

743 idx, 

744 unique_size, # out 

745 return_inverse, 

746 return_counts, 

747 num_tasks, 

748 tile_size=triton.next_power_of_2(num_tasks), 

749 num_warps=8, 

750 ) 

751 out_size = unique_size.item() + 1 

752 counts = None 

753 if return_counts: 

754 idx = idx[:out_size] 

755 counts = torch.empty_like(idx) 

756 with torch_device_fn.device(sorted_data.device.index): 

757 output_counts_flat_kernel[grid]( 

758 idx, 

759 num_tasks, # in 

760 counts, # out 

761 num_tasks=out_size, 

762 tiles_per_cta=1, 

763 tile_size=triton.next_power_of_2(out_size), 

764 num_warps=8, 

765 ) 

766 return data_out[:out_size], inverse_indices, counts 

767 

768 

769def _unique2( 

770 in0: torch.Tensor, 

771 sorted: bool = True, 

772 return_inverse: bool = False, 

773 return_counts: bool = False, 

774): 

775 if in0.numel() <= 8192: 

776 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

777 data_out, inverse_indices, counts = simple_unique_flat( 

778 sorted_data, sorted_indices, return_inverse, return_counts 

779 ) 

780 elif return_inverse: 

781 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

782 data_out, inverse_indices, counts = sorted_indices_unique_flat( 

783 sorted_data, sorted_indices, return_counts 

784 ) 

785 else: 

786 sorted_data, _ = torch.sort(in0.ravel()) 

787 data_out, inverse_indices, counts = sorted_quick_unique_flat( 

788 sorted_data, return_counts 

789 ) 

790 return ( 

791 data_out, 

792 inverse_indices if inverse_indices is None else inverse_indices.view_as(in0), 

793 counts, 

794 )