Coverage for src/flag_gems/runtime/backend/_metax/ops/unique.py: 0%
284 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.runtime import torch_device_fn
6from flag_gems.utils import triton_lang_extension as tle
7from flag_gems.utils.libentry import libentry
@libentry()
@triton.jit
def simple_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Single-tile unique over an already-sorted flat input.

    The whole input fits in one tile (tile_size is next_power_of_2(num_tasks),
    launched on a (1, 1, 1) grid by simple_unique_flat), so one program
    instance computes the unique values, optional inverse indices, and
    optional first-occurrence positions in a single pass.
    """
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # load: each element and its predecessor (element 0 is paired with itself,
    # so it never raises a "new value" flag below)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne & cumsum: a flag marks the start of each new run of equal values;
    # the inclusive cumsum of the flags is each element's 0-based output slot
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # unique_size: the last lane's cumsum is the index of the last unique
    # slot; the host (simple_unique_flat) adds 1 to obtain the count
    unique_size_mask = i0 == tile_size - 1
    tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)

    # data_out: scatter_(to=cumsum, sorted_data) — duplicates of the same run
    # all write the same value to the same slot
    tl.store(data_out_ptr + cumsum, a, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum) — maps each
    # original position back to its unique slot
    if return_inverse:
        sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)
        tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first-occurrence position of each unique value, later consumed by
    # output_counts_flat_kernel to derive per-value counts
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)
@triton.jit
def output_counts_flat_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Per-tile body: counts[i] = idx[i + 1] - idx[i].

    `idx` holds the first-occurrence position of each unique value in the
    original sorted input; the distance between consecutive entries is the
    run length.  The last unique value runs to the end of the original
    input, hence `origin_num_tasks - idx` for it.
    """
    r = tl.arange(0, tile_size)

    # load idx (first-occurrence positions); num_tasks here is the number of
    # unique values, not the original input length
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # load idx_next (first occurrence of the following unique value)
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # diff: neighbor difference, with the tail clamped to the input end
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)
@libentry()
@triton.jit
def output_counts_flat_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Derive per-unique-value counts from first-occurrence positions.

    Each CTA processes `tiles_per_cta` tiles, striding by the launch grid
    size, delegating the per-tile work to output_counts_flat_impl.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        output_counts_flat_impl(
            global_pid,
            idx_ptr,
            origin_num_tasks,  # in
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def quick_output_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Per-tile body producing both counts and the unique values.

    Same neighbor-difference counts computation as output_counts_flat_impl,
    plus a gather of the unique values themselves:
    data_out[i] = sorted_data[idx[i]].
    """
    r = tl.arange(0, tile_size)

    # load idx (first-occurrence positions; num_tasks = number of uniques)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # load idx_next
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # diff: run length of each unique value; last one runs to the input end
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)

    # data_out: gather(sorted_data, from=idx) — the element at each first
    # occurrence is exactly the unique value
    sorted_data = tl.load(sorted_data_ptr + idx, mask=mask)
    tl.store(data_out_ptr + i0, sorted_data, mask=mask)
@libentry()
@triton.jit
def quick_output_flat_kernel(
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Final pass of the counts path of sorted_quick_unique_flat.

    Emits counts and gathers unique values in one launch; per-tile work is in
    quick_output_flat_impl.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        quick_output_flat_impl(
            global_pid,
            sorted_data_ptr,
            idx_ptr,
            origin_num_tasks,  # in
            data_out_ptr,
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def local_quick_unique_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 1 of the two-pass "quick" unique (no inverse indices).

    Each tile independently compacts its new values (or, for the counts path,
    their first-occurrence positions) into its own tile-aligned slice of a
    scratch buffer, and records how many it found in tile_sum[global_pid].
    Phase 2 (global_quick_unique_flat_impl) concatenates the slices.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load each element and its predecessor; for r == 0 of a non-first tile
    # this reads the last element of the previous tile, so the "new value"
    # flags are consistent across tile boundaries
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne & cumsum (tile-local inclusive cumsum of the flags)
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # local slot: tile 0's first element has no flag, so its slots start at
    # cumsum == 0; in any later tile an element is new only if a flag fired,
    # so slots start at cumsum == 1 — hence the -1 correction. cumsum == 0 in
    # a later tile means "duplicate of the previous tile's last value" and is
    # dropped via the (offset >= 0) mask.
    local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0)
    local_unique_mask = (local_unique_offset >= 0) & mask
    if return_counts:
        # origin_idx: scatter_(to=cumsum, i0) — keep only first occurrences
        origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask
        tl.store(
            origin_idx_ptr + (offset + local_unique_offset),
            i0,
            mask=origin_idx_mask,
        )
    else:
        # local_unique: scatter_(to=cumsum, sorted_data)
        tl.store(
            local_unique_ptr + (offset + local_unique_offset), a, mask=local_unique_mask
        )

    # tile_sum: number of uniques this tile produced; tile 0 gets +1 because
    # its first element occupies slot 0 without having raised a flag
    tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num)
    tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum)
    tl.store(tile_sum_ptr + global_pid + tl.zeros_like(r), tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_quick_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Launch wrapper for phase 1 of the quick-unique scheme.

    Grid-stride over all `global_ctas_num` tiles; per-tile work is in
    local_quick_unique_flat_impl.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        local_quick_unique_flat_impl(
            global_pid,
            sorted_data_ptr,  # in
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
@triton.jit
def global_quick_unique_flat_impl(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 2 of the two-pass "quick" unique.

    Compacts the tile-aligned per-tile results written by
    local_quick_unique_flat_impl into one contiguous output.  `total` is this
    CTA's running exclusive prefix of tile counts from its earlier
    grid-stride iterations, and the updated value is returned for the next
    iteration.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # load tile_sum for the tiles between this CTA's previous grid-stride
    # iteration and this one (global_pid advances by ctas_num per iteration,
    # so only the ctas_num sums in [global_pid - ctas_num, global_pid) are
    # new; everything earlier is already folded into `total`)
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)
    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask)

    # total: exclusive prefix sum of tile counts = this tile's output base
    total += tl.sum(pre_tile_sum)
    if global_pid == global_ctas_num - 1:
        # the last tile publishes the grand total into tile_sum[last]; the
        # host reads it back as tile_sum[-1] = number of unique values
        last_tile_sum_mask = p == global_pid
        tl.store(tile_sum_ptr + p, total + cur_tile_sum, mask=last_tile_sum_mask)

    # idx or data_out: copy this tile's cur_tile_sum valid entries from its
    # tile-aligned scratch slice to their final contiguous positions
    tile_mask = r < cur_tile_sum
    out_offset = total + r
    if return_counts:
        # move origin_idx to idx_ptr
        origin_idx = tl.load(origin_idx_ptr + i0, mask=mask)
        tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask)
    else:
        # move local_unique to data_out_ptr
        local_unique = tl.load(local_unique_ptr + i0, mask=mask)
        tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask)

    # running prefix for this CTA's next grid-stride iteration
    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Launch wrapper for phase 2 of the quick-unique scheme.

    When every CTA owns exactly one tile, a single call with total = 0
    suffices; otherwise a grid-stride loop threads the running `total`
    prefix through successive calls of global_quick_unique_flat_impl.
    """
    pid = tle.program_id(0)
    # NOTE: the ctas_num argument is deliberately overwritten with the actual
    # launch grid size
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_quick_unique_flat_impl(
            pid,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_quick_unique_flat_impl(
                global_pid,
                total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
    """Deduplicate a pre-sorted flat tensor via the two-pass "quick" scheme.

    Pass 1 (local_quick_unique_flat_kernel) compacts per-tile uniques into
    tile-aligned scratch slices; pass 2 (global_quick_unique_flat_kernel)
    concatenates them contiguously and publishes the total into tile_sum[-1].
    Used when inverse indices are not requested.

    Returns a (unique_values, None, counts_or_None) triple matching the
    other *_unique_flat helpers.
    """
    numel = sorted_data.numel()
    numel_pow2 = triton.next_power_of_2(numel)

    # tile/grid sizing: cap tiles at 8192 elements, then shrink the tile when
    # the resulting grid is small enough (<= 8192 CTAs)
    tile_size = min(8192, numel_pow2)
    global_ctas_num = triton.cdiv(numel, tile_size)
    if global_ctas_num <= 8192:
        tile_size = max(32, min(triton.next_power_of_2(global_ctas_num), numel_pow2))
        global_ctas_num = triton.cdiv(numel, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = 2048 if global_ctas_num >= 65536 else global_ctas_num
    tiles_per_cta = triton.cdiv(numel, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 16  # maca support up to 16
    grid = (ctas_num, 1, 1)

    # scratch / output buffers: the counts path compacts first-occurrence
    # indices, the plain path compacts the values themselves
    counts = None
    if return_counts:
        origin_idx = torch.empty_like(sorted_data, dtype=torch.int64)
        idx = torch.empty_like(origin_idx)
        local_unique = None
        data_out = None  # allocated once the unique count is known
    else:
        origin_idx = None
        idx = None
        local_unique = torch.empty_like(sorted_data)
        data_out = torch.empty_like(sorted_data)
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )

    # launch kernels
    with torch_device_fn.device(sorted_data.device.index):
        local_quick_unique_flat_kernel[grid](
            sorted_data,  # in
            local_unique,
            origin_idx,
            tile_sum,  # out
            global_ctas_num,
            numel,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        global_quick_unique_flat_kernel[grid](
            local_unique,
            origin_idx,
            tile_sum,  # in
            data_out,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            numel,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # phase 2 published the total number of uniques into tile_sum[-1]
        out_size = tile_sum[-1].item()
        if return_counts:
            data_out = torch.empty(
                (out_size,), dtype=sorted_data.dtype, device=sorted_data.device
            )
            idx = idx[:out_size]
            counts = origin_idx[:out_size]  # reuse scratch as the counts output
            quick_output_flat_kernel[grid](
                sorted_data,
                idx,
                numel,  # in
                data_out,
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    if return_counts:
        return data_out, None, counts
    return data_out[:out_size], None, None
@triton.jit
def local_ne_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Phase 1 of the inverse-indices unique path.

    Writes a per-element "new value" flag (element differs from its
    predecessor in the sorted order) and the per-tile sum of those flags,
    which phase 2 (global_cumsum_flat_impl) turns into a global cumsum.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    # element 0 is paired with itself so it never raises a flag
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load each element and its predecessor (crosses tile boundaries for r == 0)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # compute the boundary flag
    ne_result = tl.where(i0 > 0, a != b, 0)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum (count of new values in this tile)
    tile_sum = tl.sum(ne_result)
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Launch wrapper for phase 1 of the inverse-indices unique path.

    Grid-stride over all tiles; per-tile work is in local_ne_flat_impl.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        local_ne_flat_impl(
            global_pid,
            sorted_data_ptr,  # in
            ne_result_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
        )
@triton.jit
def global_cumsum_flat_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 2 of the inverse-indices unique path.

    Promotes the tile-local flags written by local_ne_flat_impl into a global
    inclusive cumsum (each element's 0-based unique slot), then scatters the
    unique values, the inverse indices, and optionally the first-occurrence
    positions.  `total` is this CTA's running prefix of tile sums from its
    earlier grid-stride iterations and is returned updated.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # load tile_sum for the tiles since this CTA's previous grid-stride
    # iteration (global_pid advances by ctas_num per iteration, so only the
    # window [global_pid - ctas_num, global_pid) is new; the rest is already
    # folded into `total`)
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum: tile-local inclusive cumsum of the flags; `total` holds the
    # global base for this tile
    total += tl.sum(pre_tile_sum)
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    cumsum = tl.cumsum(ne_result)

    # tile_sum: the last tile publishes the last unique slot index into
    # tile_sum[last]; the host adds 1 to obtain the unique count
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(r),
            tile_sum,
            mask=last_tile_sum_mask,
        )
    cumsum += total

    # data_out: scatter_(to=cumsum, sorted_data) — duplicates write the same
    # value to the same slot
    tl.store(data_out_ptr + cumsum, sorted_data, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum) — maps each
    # original position back to its unique slot
    tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first-occurrence position of each unique value (counts path only)
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result_i1) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)

    # running prefix for this CTA's next grid-stride iteration
    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Launch wrapper for phase 2 of the inverse-indices unique path.

    When every CTA owns exactly one tile, a single call with total = 0
    suffices; otherwise the grid-stride loop threads the running `total`
    prefix through successive calls of global_cumsum_flat_impl.
    """
    pid = tle.program_id(0)
    # NOTE: the ctas_num argument is deliberately overwritten with the actual
    # launch grid size
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_cumsum_flat_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_flat_impl(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_indices_unique_flat(
    sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool
):
    """Deduplicate a pre-sorted flat tensor, producing inverse indices.

    Two launches: local_ne_flat_kernel writes per-element new-value flags and
    per-tile flag sums; global_cumsum_flat_kernel turns them into a global
    cumsum and scatters values, inverse indices and (optionally)
    first-occurrence positions.  A third launch derives counts if requested.

    Returns (unique_values, inverse_indices, counts_or_None).
    """
    numel = sorted_data.numel()
    numel_pow2 = triton.next_power_of_2(numel)

    # tile/grid sizing: cap tiles at 8192 elements, then shrink the tile when
    # the grid is small, but keep a 256/512-element floor per tile
    tile_size = min(8192, numel_pow2)
    global_ctas_num = triton.cdiv(numel, tile_size)
    if global_ctas_num <= 8192:
        floor = 512 if global_ctas_num > 32 else 256
        tile_size = max(
            floor,
            min(triton.next_power_of_2(global_ctas_num), numel_pow2),
        )
        global_ctas_num = triton.cdiv(numel, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = 8192 if global_ctas_num >= 32768 else global_ctas_num
    tiles_per_cta = triton.cdiv(numel, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 16  # maca support up to 16
    grid = (ctas_num, 1, 1)

    # working buffers
    ne_result = torch.empty_like(sorted_data, dtype=torch.bool)
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = torch.empty_like(sorted_data)
    inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    idx = torch.empty_like(inverse_indices) if return_counts else None
    counts = None

    # launch kernels
    with torch_device_fn.device(sorted_data.device.index):
        local_ne_flat_kernel[grid](
            sorted_data,  # in
            ne_result,
            tile_sum,  # out
            global_ctas_num,
            numel,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_flat_kernel[grid](
            ne_result,
            tile_sum,  # in
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            numel,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # the last tile stored the last unique slot index; +1 gives the count
        out_size = tile_sum[-1].item() + 1
        if return_counts:
            idx = idx[:out_size]
            counts = torch.empty_like(idx)
            output_counts_flat_kernel[grid](
                idx,
                numel,  # in
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts
def simple_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Unique for small pre-sorted inputs: one kernel tile covers everything.

    A single program instance (grid (1, 1, 1)) computes unique values,
    optional inverse indices and optional first-occurrence positions in one
    pass; a second tiny launch derives counts if requested.

    Returns (unique_values, inverse_indices_or_None, counts_or_None).
    """
    numel = sorted_data.numel()
    grid = (1, 1, 1)

    # output buffers (optional ones stay None when not requested)
    data_out = torch.empty_like(sorted_data)
    inverse_indices = (
        torch.empty_like(sorted_data, dtype=torch.int64) if return_inverse else None
    )
    idx = (
        torch.empty_like(sorted_data, dtype=torch.int64) if return_counts else None
    )
    unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device)

    # single-tile unique pass
    with torch_device_fn.device(sorted_data.device.index):
        simple_unique_flat_kernel[grid](
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,
            unique_size,  # out
            return_inverse,
            return_counts,
            numel,
            tile_size=triton.next_power_of_2(numel),
            num_warps=8,
        )
    # kernel stored the last unique slot index; +1 gives the count
    out_size = unique_size.item() + 1

    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            output_counts_flat_kernel[grid](
                idx,
                numel,  # in
                counts,  # out
                num_tasks=out_size,
                tiles_per_cta=1,
                tile_size=triton.next_power_of_2(out_size),
                num_warps=8,
            )
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """torch.unique-style entry point dispatching to one of three strategies.

    Small inputs (<= 8192 elements) take the single-tile path; larger inputs
    take the inverse-indices path when inverse indices are requested,
    otherwise the cheaper two-pass "quick" path.  The `sorted` flag is
    accepted for API parity but unused: the result always comes from a full
    sort of the flattened input.
    """
    if in0.numel() <= 8192:
        sorted_data, sorted_indices = torch.sort(in0.ravel())
        data_out, inverse_indices, counts = simple_unique_flat(
            sorted_data, sorted_indices, return_inverse, return_counts
        )
    elif return_inverse:
        sorted_data, sorted_indices = torch.sort(in0.ravel())
        data_out, inverse_indices, counts = sorted_indices_unique_flat(
            sorted_data, sorted_indices, return_counts
        )
    else:
        sorted_data, _ = torch.sort(in0.ravel())
        data_out, inverse_indices, counts = sorted_quick_unique_flat(
            sorted_data, return_counts
        )
    # inverse indices are computed against the flattened input; reshape them
    # back to the caller's view
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)
    return data_out, inverse_indices, counts