Coverage for src/flag_gems/runtime/backend/_ascend/ops/unique.py: 0%

314 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

12 

13 

14@libentry() 

15@triton.jit 

16def simple_unique_flat_kernel( 

17 sorted_data_ptr: tl.tensor, 

18 sorted_indices_ptr: tl.tensor, # in 

19 data_out_ptr: tl.tensor, 

20 inverse_indices_ptr: tl.tensor, 

21 idx_ptr: tl.tensor, 

22 unique_size_ptr: tl.tensor, # out 

23 return_inverse: tl.constexpr, 

24 return_counts: tl.constexpr, 

25 num_tasks: int, 

26 tile_size: tl.constexpr, 

27): 

28 i0 = tl.arange(0, tile_size) 

29 mask = i0 < num_tasks 

30 

31 # load 

32 a = tl.load(sorted_data_ptr + i0, mask=mask) 

33 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

34 b = tl.load(sorted_data_ptr + i0_prev, mask=mask) 

35 

36 # ne & cumsum 

37 ne_result = tl.where(i0 > 0, a != b, 0) 

38 cumsum = tl.cumsum(ne_result) 

39 

40 # unique_size 

41 unique_size_mask = i0 == tile_size - 1 

42 tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask) 

43 

44 # data_out: scatter_(to=cumsum, sorted_data) 

45 tl.store(data_out_ptr + cumsum, a, mask=mask) 

46 

47 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

48 if return_inverse: 

49 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

50 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

51 

52 # idx 

53 if return_counts: 

54 idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask 

55 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

56 

57 

58@triton.jit 

59def output_counts_flat_impl( 

60 global_pid, 

61 idx_ptr: tl.tensor, 

62 origin_num_tasks: int, # in 

63 counts_ptr: tl.tensor, # out 

64 num_tasks: int, 

65 tile_size: tl.constexpr, 

66): 

67 r = tl.arange(0, tile_size) 

68 

69 # load idx 

70 i0 = global_pid * tile_size + r 

71 mask = i0 < num_tasks 

72 idx = tl.load(idx_ptr + i0, mask=mask) 

73 

74 # load idx_next 

75 i0_next = i0 + 1 

76 next_mask = i0_next < num_tasks 

77 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

78 

79 # diff 

80 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

81 

82 # store counts 

83 tl.store(counts_ptr + i0, counts, mask=mask) 

84 

85 

86@libentry() 

87@triton.jit 

88def output_counts_flat_kernel( 

89 idx_ptr: tl.tensor, 

90 origin_num_tasks: int, # in 

91 counts_ptr: tl.tensor, # out 

92 num_tasks: int, 

93 tiles_per_cta: int, 

94 tile_size: tl.constexpr, 

95): 

96 pid = tle.program_id(0) 

97 ctas_num = tle.num_programs(0) 

98 # grid-stride-loop style kernel 

99 for j in range(0, tiles_per_cta): 

100 global_pid = pid + j * ctas_num 

101 output_counts_flat_impl( 

102 global_pid, 

103 idx_ptr, 

104 origin_num_tasks, # in 

105 counts_ptr, # out 

106 num_tasks, 

107 tile_size, 

108 ) 

109 

110 

111@triton.jit 

112def quick_output_flat_impl( 

113 global_pid, 

114 sorted_data_ptr: tl.tensor, 

115 idx_ptr: tl.tensor, 

116 origin_num_tasks: int, # in 

117 data_out_ptr: tl.tensor, 

118 counts_ptr: tl.tensor, # out 

119 num_tasks: int, 

120 tile_size: tl.constexpr, 

121): 

122 r = tl.arange(0, tile_size) 

123 

124 # load idx 

125 i0 = global_pid * tile_size + r 

126 mask = i0 < num_tasks 

127 idx = tl.load(idx_ptr + i0, mask=mask) 

128 

129 # load idx_next 

130 i0_next = i0 + 1 

131 next_mask = i0_next < num_tasks 

132 idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) 

133 

134 # diff 

135 counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) 

136 

137 # store counts 

138 tl.store(counts_ptr + i0, counts, mask=mask) 

139 

140 # data_out: gather(sorted_data, from=idx) 

141 sorted_data = tl.load(sorted_data_ptr + idx, mask=mask) 

142 tl.store(data_out_ptr + i0, sorted_data, mask=mask) 

143 

144 

145@libentry() 

146@triton.jit 

147def quick_output_flat_kernel( 

148 sorted_data_ptr: tl.tensor, 

149 idx_ptr: tl.tensor, 

150 origin_num_tasks: int, # in 

151 data_out_ptr: tl.tensor, 

152 counts_ptr: tl.tensor, # out 

153 num_tasks: int, 

154 tiles_per_cta: int, 

155 tile_size: tl.constexpr, 

156): 

157 pid = tle.program_id(0) 

158 ctas_num = tle.num_programs(0) 

159 # grid-stride-loop style kernel 

160 for j in range(0, tiles_per_cta): 

161 global_pid = pid + j * ctas_num 

162 quick_output_flat_impl( 

163 global_pid, 

164 sorted_data_ptr, 

165 idx_ptr, 

166 origin_num_tasks, # in 

167 data_out_ptr, 

168 counts_ptr, # out 

169 num_tasks, 

170 tile_size, 

171 ) 

172 

173 

174@triton.jit 

175def local_quick_unique_flat_impl( 

176 global_pid, 

177 sorted_data_ptr: tl.tensor, # in 

178 local_unique_ptr: tl.tensor, 

179 origin_idx_ptr: tl.tensor, 

180 tile_sum_ptr: tl.tensor, # out 

181 global_ctas_num: int, 

182 num_tasks: int, 

183 tile_size: tl.constexpr, 

184 return_counts: tl.constexpr, 

185): 

186 offset = global_pid * tile_size 

187 r = tl.arange(0, tile_size) 

188 i0 = offset + r 

189 mask = i0 < num_tasks 

190 

191 # load 

192 a = tl.load(sorted_data_ptr + i0, mask=mask, other=0) 

193 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

194 b = tl.load(sorted_data_ptr + i0_prev, mask=mask, other=0) 

195 

196 # ne & cumsum 

197 # 对于 i0=0 的位置(第一个元素),ne_result 应该是 1(它是第一个唯一值) 

198 # 对于其他位置,ne_result = (a != b) 

199 ne_result = tl.where(i0 > 0, a != b, 1) 

200 ne_result = tl.where(mask, ne_result, 0) # 只保留有效位置 

201 

202 cumsum = tl.cumsum(ne_result) 

203 

204 # 对于第一个唯一值(i0=0),cumsum=1,所以索引是 0(cumsum-1) 

205 # 对于其他唯一值,cumsum 递增 

206 local_unique_offset = cumsum - 1 # cumsum 从 1 开始,所以减 1 得到从 0 开始的索引 

207 local_unique_mask = mask 

208 

209 if return_counts: 

210 # origin_idx: 只在唯一值位置存储 

211 origin_idx_mask = ne_result.to(tl.int1) & local_unique_mask 

212 tl.store( 

213 origin_idx_ptr + (offset + local_unique_offset), 

214 i0, 

215 mask=origin_idx_mask, 

216 ) 

217 else: 

218 # local_unique: 只在唯一值位置存储 

219 store_mask = ne_result.to(tl.int1) & local_unique_mask 

220 tl.store(local_unique_ptr + (offset + local_unique_offset), a, mask=store_mask) 

221 

222 # tile_sum - 获取最后一个有效位置的 cumsum 值 

223 valid_cumsum = tl.where(mask, cumsum, 0) 

224 last_cumsum = tl.max(valid_cumsum) 

225 

226 # 直接使用 last_cumsum,不需要特殊处理第一个 tile 

227 if global_pid < global_ctas_num: 

228 tl.store(tile_sum_ptr + global_pid, last_cumsum) 

229 

230 

231@libentry() 

232@triton.jit 

233def local_quick_unique_flat_kernel( 

234 sorted_data_ptr: tl.tensor, # in 

235 local_unique_ptr: tl.tensor, 

236 origin_idx_ptr: tl.tensor, 

237 tile_sum_ptr: tl.tensor, # out 

238 global_ctas_num: int, 

239 num_tasks: int, 

240 tiles_per_cta: int, 

241 tile_size: tl.constexpr, 

242 return_counts: tl.constexpr, 

243): 

244 pid = tle.program_id(0) 

245 ctas_num = tle.num_programs(0) 

246 # grid-stride-loop style kernel 

247 for j in range(0, tiles_per_cta): 

248 global_pid = pid + j * ctas_num 

249 local_quick_unique_flat_impl( 

250 global_pid, 

251 sorted_data_ptr, # in 

252 local_unique_ptr, 

253 origin_idx_ptr, 

254 tile_sum_ptr, # out 

255 global_ctas_num, 

256 num_tasks, 

257 tile_size, 

258 return_counts, 

259 ) 

260 

261 

262@triton.jit 

263def global_quick_unique_flat_impl( 

264 global_pid, 

265 total, 

266 local_unique_ptr: tl.tensor, 

267 origin_idx_ptr: tl.tensor, 

268 tile_sum_ptr: tl.tensor, # in 

269 data_out_ptr: tl.tensor, 

270 idx_ptr: tl.tensor, # out 

271 ctas_num: int, 

272 global_ctas_num: int, 

273 next_power_global_ctas_num: tl.constexpr, 

274 num_tasks: int, 

275 tile_size: tl.constexpr, 

276 return_counts: tl.constexpr, 

277 CHUNK_SIZE: tl.constexpr, # 每个块的大小 

278 MAX_CHUNKS: tl.constexpr, # 最大块数 

279): 

280 r = tl.arange(0, tile_size) 

281 i0 = global_pid * tile_size + r 

282 mask = i0 < num_tasks 

283 

284 # load tile_sum - 使用分块处理避免UB overflow 

285 start_idx = tl.maximum(global_pid - ctas_num, 0) 

286 end_idx = tl.minimum(global_pid, global_ctas_num) 

287 

288 # 分块累加 pre_tile_sum 

289 total_sum = 0 

290 total_sum = total_sum.to(tl.int64) 

291 for chunk_id in range(MAX_CHUNKS): 

292 chunk_start = start_idx + chunk_id * CHUNK_SIZE 

293 

294 # 只有当这个chunk在有效范围内时才处理 

295 if chunk_start < end_idx: 

296 p = tl.arange(0, CHUNK_SIZE) 

297 p_idx = chunk_start + p 

298 

299 # 计算mask:需要确保索引在 [start_idx, end_idx) 范围内 

300 pre_tile_sum_mask = ( 

301 (p_idx < end_idx) & (p_idx >= start_idx) & (p_idx < global_ctas_num) 

302 ) 

303 

304 pre_tile_sum = tl.load( 

305 tile_sum_ptr + p_idx, mask=pre_tile_sum_mask, other=0 

306 ) 

307 total_sum += tl.sum(pre_tile_sum) 

308 

309 cur_tile_sum_mask = global_pid < global_ctas_num 

310 cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask, other=0) 

311 

312 # total 

313 total += total_sum 

314 

315 # tile_sum 存储 

316 if global_pid == global_ctas_num - 1: 

317 tl.store(tile_sum_ptr + global_pid, total + cur_tile_sum) 

318 

319 # idx or data_out 

320 tile_mask = r < cur_tile_sum 

321 out_offset = total + r 

322 

323 if return_counts: 

324 # move origin_idx to idx_ptr 

325 origin_idx = tl.load(origin_idx_ptr + i0, mask=mask, other=0) 

326 tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask) 

327 else: 

328 # move local_unique to data_out_ptr 

329 local_unique = tl.load(local_unique_ptr + i0, mask=mask, other=0) 

330 tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask) 

331 

332 return total 

333 

334 

335@libentry() 

336@triton.jit 

337def global_quick_unique_flat_kernel( 

338 local_unique_ptr: tl.tensor, 

339 origin_idx_ptr: tl.tensor, 

340 tile_sum_ptr: tl.tensor, # in 

341 data_out_ptr: tl.tensor, 

342 idx_ptr: tl.tensor, # out 

343 ctas_num: int, 

344 global_ctas_num: int, 

345 next_power_global_ctas_num: tl.constexpr, 

346 num_tasks: int, 

347 tiles_per_cta: int, 

348 tile_size: tl.constexpr, 

349 one_tile_per_cta: tl.constexpr, 

350 return_counts: tl.constexpr, 

351): 

352 pid = tle.program_id(0) 

353 ctas_num = tle.num_programs(0) 

354 

355 # 分块处理参数 

356 CHUNK_SIZE: tl.constexpr = 2048 # 每块处理2048个元素 

357 MAX_CHUNKS: tl.constexpr = 32 # 最多32块 (2048 * 32 = 65536) 

358 

359 if one_tile_per_cta: 

360 # monolitic kernel style 

361 global_quick_unique_flat_impl( 

362 pid, 

363 0, 

364 local_unique_ptr, 

365 origin_idx_ptr, 

366 tile_sum_ptr, # in 

367 data_out_ptr, 

368 idx_ptr, # out 

369 ctas_num, 

370 global_ctas_num, 

371 next_power_global_ctas_num, 

372 num_tasks, 

373 tile_size, 

374 return_counts, 

375 CHUNK_SIZE, 

376 MAX_CHUNKS, 

377 ) 

378 else: 

379 # grid-stride-loop style kernel 

380 total = tl.zeros([1], dtype=tl.int64) 

381 for j in range(0, tiles_per_cta): 

382 global_pid = pid + j * ctas_num 

383 total = global_quick_unique_flat_impl( 

384 global_pid, 

385 total, 

386 local_unique_ptr, 

387 origin_idx_ptr, 

388 tile_sum_ptr, # in 

389 data_out_ptr, 

390 idx_ptr, # out 

391 ctas_num, 

392 global_ctas_num, 

393 next_power_global_ctas_num, 

394 num_tasks, 

395 tile_size, 

396 return_counts, 

397 CHUNK_SIZE, 

398 MAX_CHUNKS, 

399 ) 

400 

401 

402def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): 

403 num_tasks = sorted_data.numel() 

404 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

405 tile_size = min(4096, next_power_num_tasks) 

406 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

407 

408 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

409 ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048 

410 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

411 num_warps = 8 if tiles_per_cta == 1 else 32 

412 grid = (ctas_num, 1, 1) 

413 

414 # allocate tensor 

415 if return_counts: 

416 local_unique = None 

417 origin_idx = torch.empty_like(sorted_data, dtype=torch.int64) 

418 idx = torch.empty_like(origin_idx) 

419 else: 

420 local_unique = torch.empty_like(sorted_data) 

421 origin_idx = None 

422 idx = None 

423 counts = None 

424 tile_sum = torch.empty( 

425 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

426 ) 

427 data_out = None 

428 if not return_counts: 

429 data_out = torch.empty_like(sorted_data) 

430 

431 # launch kernel 

432 with torch_device_fn.device(sorted_data.device.index): 

433 local_quick_unique_flat_kernel[grid]( 

434 sorted_data, # in 

435 local_unique, 

436 origin_idx, 

437 tile_sum, # out 

438 global_ctas_num, 

439 num_tasks, 

440 tiles_per_cta=tiles_per_cta, 

441 tile_size=tile_size, 

442 return_counts=return_counts, 

443 num_warps=num_warps, 

444 ) 

445 global_quick_unique_flat_kernel[grid]( 

446 local_unique, 

447 origin_idx, 

448 tile_sum, # in 

449 data_out, 

450 idx, # out 

451 ctas_num, 

452 global_ctas_num, 

453 next_power_global_ctas_num, 

454 num_tasks, 

455 tiles_per_cta=tiles_per_cta, 

456 tile_size=tile_size, 

457 one_tile_per_cta=tiles_per_cta == 1, 

458 return_counts=return_counts, 

459 num_warps=num_warps, 

460 ) 

461 out_size = tile_sum[-1].item() 

462 if return_counts: 

463 data_out = torch.empty( 

464 (out_size,), dtype=sorted_data.dtype, device=sorted_data.device 

465 ) 

466 idx = idx[:out_size] 

467 counts = origin_idx[:out_size] 

468 quick_output_flat_kernel[grid]( 

469 sorted_data, 

470 idx, 

471 num_tasks, # in 

472 data_out, 

473 counts, # out 

474 out_size, 

475 tiles_per_cta, 

476 tile_size, 

477 num_warps=num_warps, 

478 ) 

479 

480 if return_counts: 

481 return data_out, None, counts 

482 else: 

483 return data_out[:out_size], None, None 

484 

485 

486@triton.jit 

487def local_ne_flat_impl( 

488 global_pid, 

489 sorted_data_ptr: tl.tensor, # in 

490 ne_result_ptr: tl.tensor, 

491 tile_sum_ptr: tl.tensor, # out 

492 global_ctas_num: int, 

493 num_tasks: int, 

494 tile_size: tl.constexpr, 

495 BLOCK_SIZE_SUB: tl.constexpr, # 新增参数用于分块处理 

496): 

497 # 计算当前tile的起始位置 

498 tile_start = global_pid * tile_size 

499 

500 # 计算子块数量 

501 num_sub_blocks = triton.cdiv(tile_size, BLOCK_SIZE_SUB) 

502 

503 # 初始化tile累加和 

504 tile_sum_acc = tl.zeros([], dtype=tl.int32) 

505 

506 # 按子块索引循环处理 

507 for sub_block_idx in range(num_sub_blocks): 

508 # 计算当前子块的起始位置 

509 sub_block_start = tile_start + sub_block_idx * BLOCK_SIZE_SUB 

510 

511 # 创建子块索引 

512 r = tl.arange(0, BLOCK_SIZE_SUB) 

513 i0 = sub_block_start + r 

514 

515 # 计算mask,确保不越界 

516 mask = (i0 < num_tasks) & (i0 >= 0) 

517 i0_prev = tl.where(i0 > 0, i0 - 1, 0) 

518 

519 # load数据 

520 a = tl.load(sorted_data_ptr + i0, mask=mask, other=0) 

521 b = tl.load(sorted_data_ptr + i0_prev, mask=mask, other=0) 

522 

523 # 计算不等式结果 

524 # 特殊处理第一个元素(全局索引为0的情况) 

525 ne_result = tl.where(i0 > 0, a != b, 0) 

526 ne_result = tl.where(mask, ne_result, 0) 

527 

528 # 存储ne_result 

529 tl.store(ne_result_ptr + i0, ne_result, mask=mask) 

530 

531 # 累加到tile_sum 

532 sub_block_sum = tl.sum(ne_result) 

533 tile_sum_acc += sub_block_sum 

534 

535 # 存储tile累加和 

536 tile_sum_mask = global_pid < global_ctas_num 

537 tl.store(tile_sum_ptr + global_pid, tile_sum_acc, mask=tile_sum_mask) 

538 

539 

540@libentry() 

541@triton.jit 

542def local_ne_flat_kernel( 

543 sorted_data_ptr: tl.tensor, # in 

544 ne_result_ptr: tl.tensor, 

545 tile_sum_ptr: tl.tensor, # out 

546 global_ctas_num: int, 

547 num_tasks: int, 

548 tiles_per_cta: int, 

549 tile_size: tl.constexpr, 

550): 

551 pid = tle.program_id(0) 

552 ctas_num = tle.num_programs(0) 

553 # grid-stride-loop style kernel 

554 for j in range(0, tiles_per_cta): 

555 global_pid = pid + j * ctas_num 

556 local_ne_flat_impl( 

557 global_pid, 

558 sorted_data_ptr, # in 

559 ne_result_ptr, 

560 tile_sum_ptr, # out 

561 global_ctas_num, 

562 num_tasks, 

563 tile_size, 

564 BLOCK_SIZE_SUB=256, 

565 ) 

566 

567 

568@triton.jit 

569def global_cumsum_flat_impl( 

570 global_pid, 

571 total, 

572 ne_result_ptr: tl.tensor, 

573 tile_sum_ptr: tl.tensor, # in 

574 sorted_data_ptr: tl.tensor, 

575 sorted_indices_ptr: tl.tensor, # in 

576 data_out_ptr: tl.tensor, 

577 inverse_indices_ptr: tl.tensor, 

578 idx_ptr: tl.tensor, # out 

579 ctas_num: tl.constexpr, 

580 global_ctas_num: int, 

581 next_power_global_ctas_num: tl.constexpr, 

582 num_tasks: int, 

583 tile_size: tl.constexpr, 

584 return_counts: tl.constexpr, 

585 MAX_CTAS_NUM: tl.constexpr, 

586 CHUNK_SIZE: tl.constexpr = 512, 

587): 

588 offset = global_pid * tile_size 

589 r = tl.arange(0, tile_size) 

590 i0 = offset + r 

591 mask = i0 < num_tasks 

592 

593 # load sorted_data, sorted_indices 

594 sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) 

595 sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) 

596 

597 # 计算需要加载的tile_sum范围 

598 start_idx = tl.maximum(global_pid - ctas_num, 0) 

599 end_idx = tl.minimum(global_pid, global_ctas_num) 

600 actual_load_size = end_idx - start_idx 

601 actual_load_size = actual_load_size.to(tl.int64) 

602 

603 # 分块累加tile_sum,避免一次性分配过大的张量 

604 chunk_sum = 0 

605 chunk_sum = chunk_sum.to(tl.int64) 

606 

607 for chunk_id in range(tl.cdiv(MAX_CTAS_NUM, CHUNK_SIZE)): 

608 # 计算当前chunk的范围 

609 chunk_start = chunk_id * CHUNK_SIZE 

610 chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, actual_load_size) 

611 

612 # 只在有效chunk范围内加载 

613 if chunk_start < actual_load_size: 

614 p = tl.arange(0, CHUNK_SIZE) 

615 p_idx = start_idx + chunk_start + p 

616 

617 # 更精确的mask条件 

618 pre_tile_sum_mask = ( 

619 (p < (chunk_end - chunk_start)) 

620 & (p_idx >= start_idx) # 当前chunk内有效 

621 & (p_idx < end_idx) 

622 & (p_idx >= 0) 

623 & (p_idx < global_ctas_num) 

624 ) 

625 

626 pre_tile_sum = tl.load( 

627 tile_sum_ptr + p_idx, mask=pre_tile_sum_mask, other=0 

628 ) 

629 chunk_sum += tl.sum(pre_tile_sum) 

630 

631 # cumsum 

632 total += chunk_sum 

633 ne_result = tl.load(ne_result_ptr + i0, mask=mask) 

634 ne_result_i1 = ne_result.to(tl.int1) 

635 ne_result = ne_result.to(tl.int32) 

636 cumsum = tl.cumsum(ne_result) 

637 

638 # tile_sum 

639 if global_pid == global_ctas_num - 1: 

640 last_tile_sum_mask = i0 == num_tasks - 1 

641 tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) 

642 tl.store( 

643 tile_sum_ptr + global_pid + tl.zeros_like(r), 

644 tile_sum, 

645 mask=last_tile_sum_mask, 

646 ) 

647 cumsum += total 

648 

649 # data_out: scatter_(to=cumsum, sorted_data) 

650 tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) 

651 

652 # inverse_indices: scatter_(to=sorted_indices, cumsum) 

653 tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) 

654 

655 # idx 

656 if return_counts: 

657 idx_mask = ((i0 == 0) | ne_result_i1) & mask 

658 tl.store(idx_ptr + cumsum, i0, mask=idx_mask) 

659 

660 return total 

661 

662 

663@libentry() 

664@triton.jit 

665def global_cumsum_flat_kernel( 

666 ne_result_ptr: tl.tensor, 

667 tile_sum_ptr: tl.tensor, # in 

668 sorted_data_ptr: tl.tensor, 

669 sorted_indices_ptr: tl.tensor, # in 

670 data_out_ptr: tl.tensor, 

671 inverse_indices_ptr: tl.tensor, 

672 idx_ptr: tl.tensor, # out 

673 ctas_num: int, 

674 global_ctas_num: int, 

675 next_power_global_ctas_num: tl.constexpr, 

676 num_tasks: int, 

677 tiles_per_cta: int, 

678 tile_size: tl.constexpr, 

679 one_tile_per_cta: tl.constexpr, 

680 return_counts: tl.constexpr, 

681): 

682 pid = tle.program_id(0) 

683 ctas_num = tle.num_programs(0) 

684 MAX_CTAS_NUM: tl.constexpr = 65536 

685 

686 if one_tile_per_cta: # monolitic kernel style 

687 global_cumsum_flat_impl( 

688 pid, 

689 0, 

690 ne_result_ptr, 

691 tile_sum_ptr, # in 

692 sorted_data_ptr, 

693 sorted_indices_ptr, # in 

694 data_out_ptr, 

695 inverse_indices_ptr, 

696 idx_ptr, # out 

697 ctas_num, 

698 global_ctas_num, 

699 next_power_global_ctas_num, 

700 num_tasks, 

701 tile_size, 

702 return_counts, 

703 MAX_CTAS_NUM, 

704 ) 

705 else: # grid-stride-loop style kernel 

706 total = tl.zeros([1], dtype=tl.int64) 

707 for j in range(0, tiles_per_cta): 

708 global_pid = pid + j * ctas_num 

709 total = global_cumsum_flat_impl( 

710 global_pid, 

711 total, 

712 ne_result_ptr, 

713 tile_sum_ptr, # in 

714 sorted_data_ptr, 

715 sorted_indices_ptr, # in 

716 data_out_ptr, 

717 inverse_indices_ptr, 

718 idx_ptr, # out 

719 ctas_num, 

720 global_ctas_num, 

721 next_power_global_ctas_num, 

722 num_tasks, 

723 tile_size, 

724 return_counts, 

725 MAX_CTAS_NUM, 

726 ) 

727 

728 

729def sorted_indices_unique_flat( 

730 sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool 

731): 

732 num_tasks = sorted_data.numel() 

733 next_power_num_tasks = triton.next_power_of_2(num_tasks) 

734 if num_tasks >= 167772160: 

735 tile_size = 4096 

736 else: 

737 tile_size = min(2048, next_power_num_tasks) 

738 global_ctas_num = triton.cdiv(num_tasks, tile_size) 

739 next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num) 

740 ctas_num = global_ctas_num if global_ctas_num < 65536 else 8192 

741 tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num) 

742 grid = (ctas_num, 1, 1) 

743 # allocate tensor 

744 ne_result = torch.empty_like(sorted_data, dtype=torch.bool) 

745 tile_sum = torch.empty( 

746 (global_ctas_num,), dtype=torch.int64, device=sorted_data.device 

747 ) 

748 data_out = torch.empty_like(sorted_data) 

749 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

750 idx = None 

751 if return_counts: 

752 idx = torch.empty_like(inverse_indices) 

753 # launch kernel 

754 with torch_device_fn.device(sorted_data.device.index): 

755 local_ne_flat_kernel[grid]( 

756 sorted_data, # in 

757 ne_result, 

758 tile_sum, # out 

759 global_ctas_num, 

760 num_tasks, 

761 tiles_per_cta=tiles_per_cta, 

762 tile_size=tile_size, 

763 ) 

764 global_cumsum_flat_kernel[grid]( 

765 ne_result, 

766 tile_sum, # in 

767 sorted_data, 

768 sorted_indices, # in 

769 data_out, 

770 inverse_indices, 

771 idx, # out 

772 ctas_num, 

773 global_ctas_num, 

774 next_power_global_ctas_num, 

775 num_tasks, 

776 tiles_per_cta=tiles_per_cta, 

777 tile_size=tile_size, 

778 one_tile_per_cta=tiles_per_cta == 1, 

779 return_counts=return_counts, 

780 ) 

781 out_size = tile_sum[-1].item() + 1 

782 counts = None 

783 if return_counts: 

784 idx = idx[:out_size] 

785 counts = torch.empty_like(idx) 

786 output_counts_flat_kernel[grid]( 

787 idx, 

788 num_tasks, # in 

789 counts, # out 

790 out_size, 

791 tiles_per_cta, 

792 tile_size, 

793 ) 

794 return data_out[:out_size], inverse_indices, counts 

795 

796 

797def simple_unique_flat( 

798 sorted_data: torch.Tensor, 

799 sorted_indices: torch.Tensor, 

800 return_inverse: bool, 

801 return_counts: bool, 

802): 

803 num_tasks = sorted_data.numel() 

804 grid = (1, 1, 1) 

805 

806 # allocate tensor 

807 data_out = torch.empty_like(sorted_data) 

808 if return_inverse: 

809 inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) 

810 else: 

811 inverse_indices = None 

812 if return_counts: 

813 idx = torch.empty_like(sorted_data, dtype=torch.int64) 

814 else: 

815 idx = None 

816 unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device) 

817 

818 # launch kernel 

819 with torch_device_fn.device(sorted_data.device.index): 

820 simple_unique_flat_kernel[grid]( 

821 sorted_data, 

822 sorted_indices, # in 

823 data_out, 

824 inverse_indices, 

825 idx, 

826 unique_size, # out 

827 return_inverse, 

828 return_counts, 

829 num_tasks, 

830 tile_size=triton.next_power_of_2(num_tasks), 

831 num_warps=8, 

832 ) 

833 out_size = unique_size.item() + 1 

834 counts = None 

835 if return_counts: 

836 idx = idx[:out_size] 

837 counts = torch.empty_like(idx) 

838 with torch_device_fn.device(sorted_data.device.index): 

839 output_counts_flat_kernel[grid]( 

840 idx, 

841 num_tasks, # in 

842 counts, # out 

843 num_tasks=out_size, 

844 tiles_per_cta=1, 

845 tile_size=triton.next_power_of_2(out_size), 

846 num_warps=8, 

847 ) 

848 return data_out[:out_size], inverse_indices, counts 

849 

850 

851def _unique2( 

852 in0: torch.Tensor, 

853 sorted: bool = True, 

854 return_inverse: bool = False, 

855 return_counts: bool = False, 

856): 

857 logger.debug("GEMS_ASCEND _UNIQUE2") 

858 if in0.numel() <= 8192: 

859 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

860 data_out, inverse_indices, counts = simple_unique_flat( 

861 sorted_data, sorted_indices, return_inverse, return_counts 

862 ) 

863 elif return_inverse: 

864 sorted_data, sorted_indices = torch.sort(in0.ravel()) 

865 data_out, inverse_indices, counts = sorted_indices_unique_flat( 

866 sorted_data, sorted_indices, return_counts 

867 ) 

868 else: 

869 sorted_data, _ = torch.sort(in0.ravel()) 

870 data_out, inverse_indices, counts = sorted_quick_unique_flat( 

871 sorted_data, return_counts 

872 ) 

873 return ( 

874 data_out, 

875 inverse_indices if inverse_indices is None else inverse_indices.view_as(in0), 

876 counts, 

877 )