Coverage for src/flag_gems/runtime/backend/_hygon/ops/unique.py: 0%

287 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

14@libentry() 

15@triton.jit 

16def simple_unique_flat_kernel( 

17 sorted_data_ptr: tl.tensor, 

18 sorted_indices_ptr: tl.tensor, # in 

19 data_out_ptr: tl.tensor, 

20 inverse_indices_ptr: tl.tensor, 

21 idx_ptr: tl.tensor, 

22 unique_size_ptr: tl.tensor, # out 

23 return_inverse: tl.constexpr, 

24 return_counts: tl.constexpr, 

25 num_tasks: int, 

26 tile_size: tl.constexpr, 

27): 

28 i0 = tl.arange(0, tile_size) 

29 mask = i0 < num_tasks 

30 

31 # load 

32 a = tl.load(sorted_data_ptr + i0, mask=mask) 

33 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

34 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

35 

36 # ne & cumsum 

37 ne_result = tl.where(i0 > 0, a != b, 0) 

38 cumsum = tl.cumsum(ne_result) 

39 

40 # unique_size 

41 unique_size_mask = i0 == tile_size - 1 

42 tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask) 

43 

44 # data_out: scatter_(to=cumsum, sorted_data) 

45 tl.store(data_out_ptr + cumsum, a, mask=mask) 

46 

47 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

48 if return_inverse: 

49 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

50 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

51 

52 # idx 

53 if return_counts: 

54 idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask 

55 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

56 

57 

58@triton.jit 

59def output_counts_flat_impl( 

60 global_pid, 

61 idx_ptr: tl.tensor, 

62 origin_num_tasks: int, # in 

63 counts_ptr: tl.tensor, # out 

64 num_tasks: int, 

65 tile_size: tl.constexpr, 

66): 

67 r = tl.arange(0, tile_size) 

68 

69 # load idx 

70 i0 = global_pid * tile_size + r 

71 mask = i0 < num_tasks 

72 idx = tl.load(idx_ptr + i0, mask=mask) 

73 

74 # load idx_next 

75 i0_next = i0 + 1 

76 next_mask = i0_next < num_tasks 

77 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

78 

79 # diff 

80 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

81 

82 # store counts 

83 tl.store(counts_ptr + i0, counts, mask=mask) 

84 

85 

86@libentry() 

87@triton.jit 

88def output_counts_flat_kernel( 

89 idx_ptr: tl.tensor, 

90 origin_num_tasks: int, # in 

91 counts_ptr: tl.tensor, # out 

92 num_tasks: int, 

93 tiles_per_cta: int, 

94 tile_size: tl.constexpr, 

95): 

96 pid = tle.program_id(0) 

97 ctas_num = tle.num_programs(0) 

98 # grid-stride-loop style kernel 

99 for j in range(0, tiles_per_cta): 

100 global_pid = pid + j * ctas_num 

101 output_counts_flat_impl( 

102 global_pid, 

103 idx_ptr, 

104 origin_num_tasks, # in 

105 counts_ptr, # out 

106 num_tasks, 

107 tile_size, 

108 ) 

109 

110 

111@triton.jit 

112def quick_output_flat_impl( 

113 global_pid, 

114 sorted_data_ptr: tl.tensor, 

115 idx_ptr: tl.tensor, 

116 origin_num_tasks: int, # in 

117 data_out_ptr: tl.tensor, 

118 counts_ptr: tl.tensor, # out 

119 num_tasks: int, 

120 tile_size: tl.constexpr, 

121): 

122 r = tl.arange(0, tile_size) 

123 

124 # load idx 

125 i0 = global_pid * tile_size + r 

126 mask = i0 < num_tasks 

127 idx = tl.load(idx_ptr + i0, mask=mask) 

128 

129 # load idx_next 

130 i0_next = i0 + 1 

131 next_mask = i0_next < num_tasks 

132 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

133 

134 # diff 

135 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

136 

137 # store counts 

138 tl.store(counts_ptr + i0, counts, mask=mask) 

139 

140 # data_out: gather(sorted_data, from=idx) 

141 sorted_data = tl.load(sorted_data_ptr + idx, mask=mask) 

142 tl.store(data_out_ptr + i0, sorted_data, mask=mask) 

143 

144 

145@libentry() 

146@triton.jit 

147def quick_output_flat_kernel( 

148 sorted_data_ptr: tl.tensor, 

149 idx_ptr: tl.tensor, 

150 origin_num_tasks: int, # in 

151 data_out_ptr: tl.tensor, 

152 counts_ptr: tl.tensor, # out 

153 num_tasks: int, 

154 tiles_per_cta: int, 

155 tile_size: tl.constexpr, 

156): 

157 pid = tle.program_id(0) 

158 ctas_num = tle.num_programs(0) 

159 # grid-stride-loop style kernel 

160 for j in range(0, tiles_per_cta): 

161 global_pid = pid + j * ctas_num 

162 quick_output_flat_impl( 

163 global_pid, 

164 sorted_data_ptr, 

165 idx_ptr, 

166 origin_num_tasks, # in 

167 data_out_ptr, 

168 counts_ptr, # out 

169 num_tasks, 

170 tile_size, 

171 ) 

172 

173 

174@triton.jit 

175def local_quick_unique_flat_impl( 

176 global_pid, 

177 sorted_data_ptr: tl.tensor, # in 

178 local_unique_ptr: tl.tensor, 

179 origin_idx_ptr: tl.tensor, 

180 tile_sum_ptr: tl.tensor, # out 

181 global_ctas_num: int, 

182 num_tasks: int, 

183 tile_size: tl.constexpr, 

184 return_counts: tl.constexpr, 

185): 

186 offset = global_pid * tile_size 

187 r = tl.arange(0, tile_size) 

188 i0 = offset + r 

189 mask = i0 < num_tasks 

190 

191 # load 

192 a = tl.load(sorted_data_ptr + i0, mask=mask) 

193 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

194 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

195 

196 # ne & cumsum 

197 ne_result = tl.where(i0 > 0, a != b, 0) 

198 cumsum = tl.cumsum(ne_result) 

199 

200 # local_id or local_unique 

201 local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0) 

202 local_unique_mask = (local_unique_offset >= 0) & mask 

203 if return_counts: 

204 # origin_idx: scatter_(to=cumsum, i0) 

205 origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask 

206 tl.store( 

207 origin_idx_ptr + (offset + local_unique_offset), 

208 i0, 

209 mask=origin_idx_mask, 

210 ) 

211 else: 

212 # local_unique: scatter_(to=cumsum, sorted_data) 

213 tl.store( 

214 local_unique_ptr + (offset + local_unique_offset), a, mask=local_unique_mask 

215 ) 

216 

217 # tile_sum 

218 tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num) 

219 tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum) 

220 tl.store(tile_sum_ptr + global_pid + tl.zeros_like(r), tile_sum, mask=tile_sum_mask) 

221 

222 

223@libentry() 

224@triton.jit 

225def local_quick_unique_flat_kernel( 

226 sorted_data_ptr: tl.tensor, # in 

227 local_unique_ptr: tl.tensor, 

228 origin_idx_ptr: tl.tensor, 

229 tile_sum_ptr: tl.tensor, # out 

230 global_ctas_num: int, 

231 num_tasks: int, 

232 tiles_per_cta: int, 

233 tile_size: tl.constexpr, 

234 return_counts: tl.constexpr, 

235): 

236 pid = tle.program_id(0) 

237 ctas_num = tle.num_programs(0) 

238 # grid-stride-loop style kernel 

239 for j in range(0, tiles_per_cta): 

240 global_pid = pid + j * ctas_num 

241 local_quick_unique_flat_impl( 

242 global_pid, 

243 sorted_data_ptr, # in 

244 local_unique_ptr, 

245 origin_idx_ptr, 

246 tile_sum_ptr, # out 

247 global_ctas_num, 

248 num_tasks, 

249 tile_size, 

250 return_counts, 

251 ) 

252 

253 

254@triton.jit 

255def global_quick_unique_flat_impl( 

256 global_pid, 

257 total, 

258 local_unique_ptr: tl.tensor, 

259 origin_idx_ptr: tl.tensor, 

260 tile_sum_ptr: tl.tensor, # in 

261 data_out_ptr: tl.tensor, 

262 idx_ptr: tl.tensor, # out 

263 ctas_num: int, 

264 global_ctas_num: int, 

265 next_power_global_ctas_num: tl.constexpr, 

266 num_tasks: int, 

267 tile_size: tl.constexpr, 

268 return_counts: tl.constexpr, 

269): 

270 r = tl.arange(0, tile_size) 

271 i0 = global_pid * tile_size + r 

272 mask = i0 < num_tasks 

273 

274 # load tile_sum 

275 p = tl.arange(0, next_power_global_ctas_num) 

276 pre_tile_sum_mask = ( 

277 (p >= global_pid - ctas_num) 

278 & (p < global_pid) 

279 & (p >= 0) 

280 & (p < global_ctas_num) 

281 ) 

282 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

283 cur_tile_sum_mask = global_pid < global_ctas_num 

284 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) 

285 

286 # total 

287 total += tl.sum(pre_tile_sum) 

288 if global_pid == global_ctas_num - 1: 

289 last_tile_sum_mask = p == global_pid 

290 tl.store(tile_sum_ptr + p, total + cur_tile_sum, mask=last_tile_sum_mask) 

291 

292 # idx or data_out 

293 tile_mask = r < cur_tile_sum 

294 out_offset = total + r 

295 if return_counts: 

296 # move origin_idx to idx_ptr 

297 origin_idx = tl.load(origin_idx_ptr + i0, mask=mask) 

298 tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask) 

299 else: 

300 # move local_unique to data_out_ptr 

301 local_unique = tl.load(local_unique_ptr + i0, mask=mask) 

302 tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask) 

303 

304 return total 

305 

306 

307@libentry() 

308@triton.jit 

309def global_quick_unique_flat_kernel( 

310 local_unique_ptr: tl.tensor, 

311 origin_idx_ptr: tl.tensor, 

312 tile_sum_ptr: tl.tensor, # in 

313 data_out_ptr: tl.tensor, 

314 idx_ptr: tl.tensor, # out 

315 ctas_num: int, 

316 global_ctas_num: int, 

317 next_power_global_ctas_num: tl.constexpr, 

318 num_tasks: int, 

319 tiles_per_cta: int, 

320 tile_size: tl.constexpr, 

321 one_tile_per_cta: tl.constexpr, 

322 return_counts: tl.constexpr, 

323): 

324 pid = tle.program_id(0) 

325 ctas_num = tle.num_programs(0) 

326 if one_tile_per_cta: # monolitic kernel style 

327 global_quick_unique_flat_impl( 

328 pid, 

329 0, 

330 local_unique_ptr, 

331 origin_idx_ptr, 

332 tile_sum_ptr, # in 

333 data_out_ptr, 

334 idx_ptr, # out 

335 ctas_num, 

336 global_ctas_num, 

337 next_power_global_ctas_num, 

338 num_tasks, 

339 tile_size, 

340 return_counts, 

341 ) 

342 else: # grid-stride-loop style kernel 

343 total = tl.zeros([1], dtype=tl.int64) 

344 for j in range(0, tiles_per_cta): 

345 global_pid = pid + j * ctas_num 

346 total = global_quick_unique_flat_impl( 

347 global_pid, 

348 total, 

349 local_unique_ptr, 

350 origin_idx_ptr, 

351 tile_sum_ptr, # in 

352 data_out_ptr, 

353 idx_ptr, # out 

354 ctas_num, 

355 global_ctas_num, 

356 next_power_global_ctas_num, 

357 num_tasks, 

358 tile_size, 

359 return_counts, 

360 ) 

361 

362 

363def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): 

364 num_tasks = sorted_data.numel() 

365 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

366 tile_size = min(8192, next_power_num_tasks) 

367 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

368 if global_ctas_num <= 8192: 

369 tile_size = max( 

370 32, min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks) 

371 ) 

372 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

373 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

374 ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048 

375 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

376 num_warps = 8 if tiles_per_cta == 1 else 32 

377 grid = (ctas_num, 1, 1) 

378 

379 # allocate tensor 

380 if return_counts: 

381 local_unique = None 

382 origin_idx = torch.empty_like(sorted_data, dtype=torch.int64) 

383 idx = torch.empty_like(origin_idx) 

384 else: 

385 local_unique = torch.empty_like(sorted_data) 

386 origin_idx = None 

387 idx = None 

388 counts = None 

389 tile_sum = torch.empty( 

390 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

391 ) 

392 data_out = None 

393 if not return_counts: 

394 data_out = torch.empty_like(sorted_data) 

395 

396 # launch kernel 

397 with torch_device_fn.device(sorted_data.device.index): 

398 local_quick_unique_flat_kernel[grid]( 

399 sorted_data, # in 

400 local_unique, 

401 origin_idx, 

402 tile_sum, # out 

403 global_ctas_num, 

404 num_tasks, 

405 tiles_per_cta=tiles_per_cta, 

406 tile_size=tile_size, 

407 return_counts=return_counts, 

408 num_warps=num_warps, 

409 ) 

410 global_quick_unique_flat_kernel[grid]( 

411 local_unique, 

412 origin_idx, 

413 tile_sum, # in 

414 data_out, 

415 idx, # out 

416 ctas_num, 

417 global_ctas_num, 

418 next_power_global_ctas_num, 

419 num_tasks, 

420 tiles_per_cta=tiles_per_cta, 

421 tile_size=tile_size, 

422 one_tile_per_cta=tiles_per_cta == 1, 

423 return_counts=return_counts, 

424 num_warps=num_warps, 

425 ) 

426 out_size = tile_sum[-1].item() 

427 if return_counts: 

428 data_out = torch.empty( 

429 (out_size,), dtype=sorted_data.dtype, device=sorted_data.device 

430 ) 

431 idx = idx[:out_size] 

432 counts = origin_idx[:out_size] 

433 quick_output_flat_kernel[grid]( 

434 sorted_data, 

435 idx, 

436 num_tasks, # in 

437 data_out, 

438 counts, # out 

439 out_size, 

440 tiles_per_cta, 

441 tile_size, 

442 num_warps=num_warps, 

443 ) 

444 

445 if return_counts: 

446 return data_out, None, counts 

447 else: 

448 return data_out[:out_size], None, None 

449 

450 

451@triton.jit 

452def local_ne_flat_impl( 

453 global_pid, 

454 sorted_data_ptr: tl.tensor, # in 

455 ne_result_ptr: tl.tensor, 

456 tile_sum_ptr: tl.tensor, # out 

457 global_ctas_num: int, 

458 num_tasks: int, 

459 tile_size: tl.constexpr, 

460): 

461 r = tl.arange(0, tile_size) 

462 i0 = global_pid * tile_size + r 

463 mask = i0 < num_tasks 

464 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

465 

466 # load 

467 a = tl.load(sorted_data_ptr + i0, mask=mask) 

468 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

469 

470 # compute 

471 ne_result = tl.where(i0 > 0, a != b, 0) 

472 

473 # store ne_result 

474 tl.store(ne_result_ptr + i0, ne_result, mask=mask) 

475 

476 # store tile_sum 

477 tile_sum = tl.sum(ne_result) 

478 tile_sum_mask = global_pid < global_ctas_num 

479 tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask) 

480 

481 

482@libentry() 

483@triton.jit 

484def local_ne_flat_kernel( 

485 sorted_data_ptr: tl.tensor, # in 

486 ne_result_ptr: tl.tensor, 

487 tile_sum_ptr: tl.tensor, # out 

488 global_ctas_num: int, 

489 num_tasks: int, 

490 tiles_per_cta: int, 

491 tile_size: tl.constexpr, 

492): 

493 pid = tle.program_id(0) 

494 ctas_num = tle.num_programs(0) 

495 # grid-stride-loop style kernel 

496 for j in range(0, tiles_per_cta): 

497 global_pid = pid + j * ctas_num 

498 local_ne_flat_impl( 

499 global_pid, 

500 sorted_data_ptr, # in 

501 ne_result_ptr, 

502 tile_sum_ptr, # out 

503 global_ctas_num, 

504 num_tasks, 

505 tile_size, 

506 ) 

507 

508 

509@triton.jit 

510def global_cumsum_flat_impl( 

511 global_pid, 

512 total, 

513 ne_result_ptr: tl.tensor, 

514 tile_sum_ptr: tl.tensor, # in 

515 sorted_data_ptr: tl.tensor, 

516 sorted_indices_ptr: tl.tensor, # in 

517 data_out_ptr: tl.tensor, 

518 inverse_indices_ptr: tl.tensor, 

519 idx_ptr: tl.tensor, # out 

520 ctas_num: tl.constexpr, 

521 global_ctas_num: int, 

522 next_power_global_ctas_num: tl.constexpr, 

523 num_tasks: int, 

524 tile_size: tl.constexpr, 

525 return_counts: tl.constexpr, 

526): 

527 offset = global_pid * tile_size 

528 r = tl.arange(0, tile_size) 

529 i0 = offset + r 

530 mask = i0 < num_tasks 

531 

532 # load sorted_data, sorted_indices 

533 sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

534 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

535 

536 # load tile_sum 

537 p = tl.arange(0, next_power_global_ctas_num) 

538 pre_tile_sum_mask = ( 

539 (p >= global_pid - ctas_num) 

540 & (p < global_pid) 

541 & (p >= 0) 

542 & (p < global_ctas_num) 

543 ) 

544 pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) 

545 

546 # cumsum 

547 total += tl.sum(pre_tile_sum) 

548 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

549 ne_result_i1 = ne_result.to(tl.int1) 

550 ne_result = ne_result.to(tl.int32) 

551 cumsum = tl.cumsum(ne_result) 

552 

553 # tile_sum 

554 if global_pid == global_ctas_num - 1: 

555 last_tile_sum_mask = i0 == num_tasks - 1 

556 tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) 

557 tl.store( 

558 tile_sum_ptr + global_pid + tl.zeros_like(r), 

559 tile_sum, 

560 mask=last_tile_sum_mask, 

561 ) 

562 cumsum += total 

563 

564 # data_out: scatter_(to=cumsum, sorted_data) 

565 tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) 

566 

567 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

568 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

569 

570 # idx 

571 if return_counts: 

572 idx_mask = ((i0 == 0) | ne_result_i1) & mask 

573 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

574 

575 return total 

576 

577 

578@libentry() 

579@triton.jit 

580def global_cumsum_flat_kernel( 

581 ne_result_ptr: tl.tensor, 

582 tile_sum_ptr: tl.tensor, # in 

583 sorted_data_ptr: tl.tensor, 

584 sorted_indices_ptr: tl.tensor, # in 

585 data_out_ptr: tl.tensor, 

586 inverse_indices_ptr: tl.tensor, 

587 idx_ptr: tl.tensor, # out 

588 ctas_num: int, 

589 global_ctas_num: int, 

590 next_power_global_ctas_num: tl.constexpr, 

591 num_tasks: int, 

592 tiles_per_cta: int, 

593 tile_size: tl.constexpr, 

594 one_tile_per_cta: tl.constexpr, 

595 return_counts: tl.constexpr, 

596): 

597 pid = tle.program_id(0) 

598 ctas_num = tle.num_programs(0) 

599 if one_tile_per_cta: # monolitic kernel style 

600 global_cumsum_flat_impl( 

601 pid, 

602 0, 

603 ne_result_ptr, 

604 tile_sum_ptr, # in 

605 sorted_data_ptr, 

606 sorted_indices_ptr, # in 

607 data_out_ptr, 

608 inverse_indices_ptr, 

609 idx_ptr, # out 

610 ctas_num, 

611 global_ctas_num, 

612 next_power_global_ctas_num, 

613 num_tasks, 

614 tile_size, 

615 return_counts, 

616 ) 

617 else: # grid-stride-loop style kernel 

618 total = tl.zeros([1], dtype=tl.int64) 

619 for j in range(0, tiles_per_cta): 

620 global_pid = pid + j * ctas_num 

621 total = global_cumsum_flat_impl( 

622 global_pid, 

623 total, 

624 ne_result_ptr, 

625 tile_sum_ptr, # in 

626 sorted_data_ptr, 

627 sorted_indices_ptr, # in 

628 data_out_ptr, 

629 inverse_indices_ptr, 

630 idx_ptr, # out 

631 ctas_num, 

632 global_ctas_num, 

633 next_power_global_ctas_num, 

634 num_tasks, 

635 tile_size, 

636 return_counts, 

637 ) 

638 

639 

640def sorted_indices_unique_flat( 

641 sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool 

642): 

643 num_tasks = sorted_data.numel() 

644 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

645 tile_size = min(8192, next_power_num_tasks) 

646 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

647 if global_ctas_num <= 8192: 

648 min_tile_size = 512 if global_ctas_num > 32 else 256 

649 tile_size = max( 

650 min_tile_size, 

651 min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks), 

652 ) 

653 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

654 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

655 ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192 

656 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

657 num_warps = 8 if tiles_per_cta == 1 else 32 

658 grid = (ctas_num, 1, 1) 

659 

660 # allocate tensor 

661 ne_result = torch.empty_like(sorted_data, dtype=torch.bool) 

662 tile_sum = torch.empty( 

663 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

664 ) 

665 data_out = torch.empty_like(sorted_data) 

666 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

667 idx = None 

668 if return_counts: 

669 idx = torch.empty_like(inverse_indices) 

670 

671 # launch kernel 

672 with torch_device_fn.device(sorted_data.device.index): 

673 local_ne_flat_kernel[grid]( 

674 sorted_data, # in 

675 ne_result, 

676 tile_sum, # out 

677 global_ctas_num, 

678 num_tasks, 

679 tiles_per_cta=tiles_per_cta, 

680 tile_size=tile_size, 

681 num_warps=num_warps, 

682 ) 

683 global_cumsum_flat_kernel[grid]( 

684 ne_result, 

685 tile_sum, # in 

686 sorted_data, 

687 sorted_indices, # in 

688 data_out, 

689 inverse_indices, 

690 idx, # out 

691 ctas_num, 

692 global_ctas_num, 

693 next_power_global_ctas_num, 

694 num_tasks, 

695 tiles_per_cta=tiles_per_cta, 

696 tile_size=tile_size, 

697 one_tile_per_cta=tiles_per_cta == 1, 

698 return_counts=return_counts, 

699 num_warps=num_warps, 

700 ) 

701 out_size = tile_sum[-1].item() + 1 

702 counts = None 

703 if return_counts: 

704 idx = idx[:out_size] 

705 counts = torch.empty_like(idx) 

706 output_counts_flat_kernel[grid]( 

707 idx, 

708 num_tasks, # in 

709 counts, # out 

710 out_size, 

711 tiles_per_cta, 

712 tile_size, 

713 num_warps=num_warps, 

714 ) 

715 

716 return data_out[:out_size], inverse_indices, counts 

717 

718 

719def simple_unique_flat( 

720 sorted_data: torch.Tensor, 

721 sorted_indices: torch.Tensor, 

722 return_inverse: bool, 

723 return_counts: bool, 

724): 

725 num_tasks = sorted_data.numel() 

726 grid = (1, 1, 1) 

727 

728 # allocate tensor 

729 data_out = torch.empty_like(sorted_data) 

730 if return_inverse: 

731 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

732 else: 

733 inverse_indices = None 

734 if return_counts: 

735 idx = torch.empty_like(sorted_data, dtype=torch.int64) 

736 else: 

737 idx = None 

738 unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device) 

739 

740 # launch kernel 

741 with torch_device_fn.device(sorted_data.device.index): 

742 simple_unique_flat_kernel[grid]( 

743 sorted_data, 

744 sorted_indices, # in 

745 data_out, 

746 inverse_indices, 

747 idx, 

748 unique_size, # out 

749 return_inverse, 

750 return_counts, 

751 num_tasks, 

752 tile_size=triton.next_power_of_2(num_tasks), 

753 num_warps=8, 

754 ) 

755 out_size = unique_size.item() + 1 

756 counts = None 

757 if return_counts: 

758 idx = idx[:out_size] 

759 counts = torch.empty_like(idx) 

760 with torch_device_fn.device(sorted_data.device.index): 

761 output_counts_flat_kernel[grid]( 

762 idx, 

763 num_tasks, # in 

764 counts, # out 

765 num_tasks=out_size, 

766 tiles_per_cta=1, 

767 tile_size=triton.next_power_of_2(out_size), 

768 num_warps=8, 

769 ) 

770 return data_out[:out_size], inverse_indices, counts 

771 

772 

773def _unique2( 

774 in0: torch.Tensor, 

775 sorted: bool = True, 

776 return_inverse: bool = False, 

777 return_counts: bool = False, 

778): 

779 logger.debug("GEMS SORT") 

780 if in0.numel() <= 8192: 

781 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

782 data_out, inverse_indices, counts = simple_unique_flat( 

783 sorted_data, sorted_indices, return_inverse, return_counts 

784 ) 

785 elif return_inverse: 

786 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

787 data_out, inverse_indices, counts = sorted_indices_unique_flat( 

788 sorted_data, sorted_indices, return_counts 

789 ) 

790 else: 

791 sorted_data, _ = torch.sort(in0.ravel()) 

792 data_out, inverse_indices, counts = sorted_quick_unique_flat( 

793 sorted_data, return_counts 

794 ) 

795 return ( 

796 data_out, 

797 inverse_indices if inverse_indices is None else inverse_indices.view_as(in0), 

798 counts, 

799 )