Coverage for src/flag_gems/runtime/backend/_hygon/ops/unique.py: 0%
287 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
9from flag_gems.utils.libentry import libentry
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def simple_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Single-tile unique over sorted input (whole problem fits in one CTA).

    Expects ``tile_size >= num_tasks`` (the host passes
    ``triton.next_power_of_2(num_tasks)``). Writes the deduplicated values to
    ``data_out_ptr``, the 0-based index of the last unique element to
    ``unique_size_ptr`` (so the host adds 1 to get the count), and optionally
    the inverse mapping and the start offset of each unique run.
    """
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # Load each element and its predecessor (element 0 compares with itself,
    # so its "not equal" flag is forced to 0 below).
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne_result[i] == 1 iff element i starts a new unique run; the inclusive
    # cumsum then gives each element's destination slot in the unique output.
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # Store the last cumsum value (index of the last unique slot) as the size.
    unique_size_mask = i0 == tile_size - 1
    tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)

    # data_out: scatter_(to=cumsum, sorted_data) — duplicates collapse onto
    # the same slot, all writing the same value.
    tl.store(data_out_ptr + cumsum, a, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum) — maps each
    # original position to its unique-output slot.
    if return_inverse:
        sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)
        tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: for each unique value, record where its run starts in the sorted
    # data (consumed later by output_counts_flat_kernel to derive counts).
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)
@triton.jit
def output_counts_flat_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Derive per-unique counts from run-start offsets for one tile.

    ``idx_ptr`` holds, for each unique value, the offset where its run starts
    in the sorted data; ``counts[i] = idx[i + 1] - idx[i]``, and the final
    run is closed with ``origin_num_tasks`` (the sorted data's length).
    """
    r = tl.arange(0, tile_size)

    # Load this tile's run-start offsets.
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # Load the next run's start offset (masked out past the last unique).
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # Count = distance to the next run start; the last run extends to the
    # end of the original sorted data.
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)
@libentry()
@triton.jit
def output_counts_flat_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride driver for ``output_counts_flat_impl``.

    Each CTA processes ``tiles_per_cta`` tiles spaced ``ctas_num`` apart so
    an arbitrarily large ``num_tasks`` can be covered by a fixed grid.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        output_counts_flat_impl(
            global_pid,
            idx_ptr,
            origin_num_tasks,  # in
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def quick_output_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Per-tile: compute counts from run-start offsets AND gather the unique
    values in one pass.

    Same count derivation as ``output_counts_flat_impl`` (difference of
    consecutive run starts, last run closed by ``origin_num_tasks``), plus a
    gather of ``sorted_data[idx[i]]`` into the compact output.
    """
    r = tl.arange(0, tile_size)

    # Load this tile's run-start offsets.
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # Load the next run's start offset.
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # Count = distance to the next run start (last run ends at the data end).
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)

    # data_out: gather(sorted_data, from=idx) — each run's first element is
    # its unique representative.
    sorted_data = tl.load(sorted_data_ptr + idx, mask=mask)
    tl.store(data_out_ptr + i0, sorted_data, mask=mask)
@libentry()
@triton.jit
def quick_output_flat_kernel(
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride driver for ``quick_output_flat_impl``."""
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        quick_output_flat_impl(
            global_pid,
            sorted_data_ptr,
            idx_ptr,
            origin_num_tasks,  # in
            data_out_ptr,
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def local_quick_unique_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 1 of the no-inverse unique: per-tile local dedup.

    Each tile compacts its own unique values (or their run-start offsets when
    counts are requested) into a tile-local region of the scratch buffer, and
    records its number of unique elements in ``tile_sum_ptr`` for the global
    prefix pass (``global_quick_unique_flat_impl``).
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # Load each element and its predecessor across the tile boundary (global
    # index 0 compares with itself; its flag is forced to 0 below).
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne_result[i] == 1 iff element i starts a new unique run; cumsum gives
    # the local destination slot within this tile.
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # For tiles after the first, the leading element only starts a new run if
    # ne_result fired, so slot 0 is shifted down by one; negative offsets are
    # masked out (the element continues the previous tile's run).
    local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0)
    local_unique_mask = (local_unique_offset >= 0) & mask
    if return_counts:
        # origin_idx: scatter_(to=cumsum, i0) — record each run's start
        # offset; the values themselves are gathered later from idx.
        origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask
        tl.store(
            origin_idx_ptr + (offset + local_unique_offset),
            i0,
            mask=origin_idx_mask,
        )
    else:
        # local_unique: scatter_(to=cumsum, sorted_data) — compact the tile's
        # unique values into its scratch region.
        tl.store(
            local_unique_ptr + (offset + local_unique_offset), a, mask=local_unique_mask
        )

    # tile_sum: number of uniques in this tile. The first tile adds 1 because
    # its first element is always unique but has ne_result == 0.
    tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num)
    tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum)
    tl.store(tile_sum_ptr + global_pid + tl.zeros_like(r), tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_quick_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Grid-stride driver for ``local_quick_unique_flat_impl``."""
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        local_quick_unique_flat_impl(
            global_pid,
            sorted_data_ptr,  # in
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
@triton.jit
def global_quick_unique_flat_impl(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 2 of the no-inverse unique: compact tile-local results globally.

    ``total`` carries the exclusive prefix of unique counts accumulated by
    this CTA across its previous grid-stride iterations; only the most recent
    ``ctas_num`` tile sums need to be summed here because earlier tiles were
    already folded into ``total`` by prior iterations. Returns the updated
    ``total`` for the next iteration.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # Load the tile sums of the tiles between this CTA's previous tile and
    # the current one (the window of size ctas_num not yet in `total`).
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)
    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask)

    # total = number of uniques in all tiles before this one.
    total += tl.sum(pre_tile_sum)
    if global_pid == global_ctas_num - 1:
        # The last tile writes the grand total back into its tile_sum slot;
        # the host reads it via tile_sum[-1].item().
        last_tile_sum_mask = p == global_pid
        tl.store(tile_sum_ptr + p, total + cur_tile_sum, mask=last_tile_sum_mask)

    # Copy this tile's compacted entries to their final global positions:
    # run-start offsets (for counts) or the unique values themselves.
    tile_mask = r < cur_tile_sum
    out_offset = total + r
    if return_counts:
        # move origin_idx to idx_ptr
        origin_idx = tl.load(origin_idx_ptr + i0, mask=mask)
        tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask)
    else:
        # move local_unique to data_out_ptr
        local_unique = tl.load(local_unique_ptr + i0, mask=mask)
        tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask)

    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Driver for ``global_quick_unique_flat_impl``.

    Runs either one tile per CTA (monolithic) or a grid-stride loop in which
    each CTA threads its running unique-count prefix ``total`` through the
    per-tile impl calls.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_quick_unique_flat_impl(
            pid,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_quick_unique_flat_impl(
                global_pid,
                total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
    """Unique over a pre-sorted flat tensor, without inverse indices.

    Two-phase scheme: a local per-tile dedup followed by a global compaction
    using per-tile unique counts, plus an optional third kernel that gathers
    values and derives counts from run-start offsets.

    Returns ``(unique_values, None, counts_or_None)`` — the inverse slot is
    always ``None`` on this path.
    """
    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    # Heuristic tile sizing: cap at 8192, then rebalance toward roughly
    # sqrt-shaped tiles when the grid would be small.
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    if global_ctas_num <= 8192:
        tile_size = max(
            32, min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks)
        )
        global_ctas_num = triton.cdiv(num_tasks, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the launched grid; larger problems fall back to grid-stride loops.
    ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # Allocate scratch: with counts we track run-start offsets (origin_idx /
    # idx); without counts we compact values directly (local_unique).
    if return_counts:
        local_unique = None
        origin_idx = torch.empty_like(sorted_data, dtype=torch.int64)
        idx = torch.empty_like(origin_idx)
    else:
        local_unique = torch.empty_like(sorted_data)
        origin_idx = None
        idx = None
    counts = None
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = None
    if not return_counts:
        data_out = torch.empty_like(sorted_data)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        local_quick_unique_flat_kernel[grid](
            sorted_data,  # in
            local_unique,
            origin_idx,
            tile_sum,  # out
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        global_quick_unique_flat_kernel[grid](
            local_unique,
            origin_idx,
            tile_sum,  # in
            data_out,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # The last tile wrote the grand total of uniques into tile_sum[-1].
        out_size = tile_sum[-1].item()
        if return_counts:
            data_out = torch.empty(
                (out_size,), dtype=sorted_data.dtype, device=sorted_data.device
            )
            idx = idx[:out_size]
            # Reuse the origin_idx scratch buffer as the counts output.
            counts = origin_idx[:out_size]
            quick_output_flat_kernel[grid](
                sorted_data,
                idx,
                num_tasks,  # in
                data_out,
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    if return_counts:
        return data_out, None, counts
    else:
        return data_out[:out_size], None, None
@triton.jit
def local_ne_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Phase 1 of the inverse-producing unique: per-tile boundary flags.

    Writes, for each element, whether it differs from its predecessor
    (``ne_result``), and this tile's flag total into ``tile_sum_ptr`` for the
    global cumulative pass.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    # Global index 0 compares with itself; its flag is forced to 0 below.
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne_result[i] == 1 iff element i starts a new unique run.
    ne_result = tl.where(i0 > 0, a != b, 0)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum: number of new runs starting in this tile.
    tile_sum = tl.sum(ne_result)
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride driver for ``local_ne_flat_impl``."""
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        local_ne_flat_impl(
            global_pid,
            sorted_data_ptr,  # in
            ne_result_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
        )
@triton.jit
def global_cumsum_flat_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Phase 2 of the inverse-producing unique: global cumsum + scatter.

    Turns per-tile boundary flags into global unique slots (local cumsum +
    ``total``, the prefix carried across this CTA's grid-stride iterations),
    then scatters the unique values, the inverse mapping, and optionally the
    run-start offsets. Returns the updated ``total``.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # Load the tile sums of the tiles between this CTA's previous tile and
    # the current one (earlier tiles are already folded into `total`).
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # Exclusive prefix for this tile, then the tile-local inclusive cumsum.
    total += tl.sum(pre_tile_sum)
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    cumsum = tl.cumsum(ne_result)

    # The last element's global cumsum (index of the last unique) is written
    # back into this tile's tile_sum slot; the host reads tile_sum[-1] + 1.
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(r),
            tile_sum,
            mask=last_tile_sum_mask,
        )
    cumsum += total

    # data_out: scatter_(to=cumsum, sorted_data) — duplicates collapse onto
    # the same slot.
    tl.store(data_out_ptr + cumsum, sorted_data, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum) — maps each
    # original position to its unique-output slot.
    tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: record each run's start offset for the later counts kernel.
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result_i1) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)

    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Driver for ``global_cumsum_flat_impl``.

    Runs either one tile per CTA (monolithic) or a grid-stride loop in which
    each CTA threads its running prefix ``total`` through the per-tile impl
    calls.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_cumsum_flat_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_flat_impl(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_indices_unique_flat(
    sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool
):
    """Unique over a pre-sorted flat tensor, producing inverse indices.

    ``sorted_indices`` is the permutation returned by ``torch.sort`` and is
    needed to scatter the inverse mapping back to original positions.

    Returns ``(unique_values, inverse_indices, counts_or_None)``.
    """
    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    # Heuristic tile sizing: cap at 8192, then rebalance when the grid is
    # small, with a floor that depends on the grid size.
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    if global_ctas_num <= 8192:
        min_tile_size = 512 if global_ctas_num > 32 else 256
        tile_size = max(
            min_tile_size,
            min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks),
        )
        global_ctas_num = triton.cdiv(num_tasks, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the launched grid; larger problems fall back to grid-stride loops.
    ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # allocate tensor
    ne_result = torch.empty_like(sorted_data, dtype=torch.bool)
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = torch.empty_like(sorted_data)
    inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    idx = None
    if return_counts:
        idx = torch.empty_like(inverse_indices)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        local_ne_flat_kernel[grid](
            sorted_data,  # in
            ne_result,
            tile_sum,  # out
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_flat_kernel[grid](
            ne_result,
            tile_sum,  # in
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # tile_sum[-1] holds the 0-based index of the last unique element, so
        # the number of uniques is that value plus one.
        out_size = tile_sum[-1].item() + 1
        counts = None
        if return_counts:
            idx = idx[:out_size]
            counts = torch.empty_like(idx)
            output_counts_flat_kernel[grid](
                idx,
                num_tasks,  # in
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts
def simple_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Unique over a small pre-sorted flat tensor using a single-CTA kernel.

    The whole input is processed in one tile (``tile_size`` is the next power
    of two of ``num_tasks``), so no cross-tile coordination is needed.

    Returns ``(unique_values, inverse_indices_or_None, counts_or_None)``.
    """
    num_tasks = sorted_data.numel()
    grid = (1, 1, 1)

    # allocate tensor
    data_out = torch.empty_like(sorted_data)
    if return_inverse:
        inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        inverse_indices = None
    if return_counts:
        idx = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        idx = None
    unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        simple_unique_flat_kernel[grid](
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,
            unique_size,  # out
            return_inverse,
            return_counts,
            num_tasks,
            tile_size=triton.next_power_of_2(num_tasks),
            num_warps=8,
        )
    # The kernel stores the 0-based index of the last unique element, so the
    # number of uniques is that value plus one.
    out_size = unique_size.item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            output_counts_flat_kernel[grid](
                idx,
                num_tasks,  # in
                counts,  # out
                num_tasks=out_size,
                tiles_per_cta=1,
                tile_size=triton.next_power_of_2(out_size),
                num_warps=8,
            )
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Compute the unique elements of ``in0`` over its flattened view.

    Dispatches to one of three specialized paths: a single-CTA kernel for
    small inputs, an inverse-producing pipeline when the inverse mapping is
    requested, and a faster pipeline otherwise. Returns a tuple
    ``(unique_values, inverse_indices_or_None, counts_or_None)``, with the
    inverse indices (when present) reshaped to ``in0``'s original shape.
    """
    logger.debug("GEMS SORT")
    flattened = in0.ravel()
    if in0.numel() <= 8192:
        # Small input: everything fits in one tile.
        sorted_data, sorted_indices = torch.sort(flattened)
        data_out, inverse_indices, counts = simple_unique_flat(
            sorted_data, sorted_indices, return_inverse, return_counts
        )
    elif return_inverse:
        # The inverse mapping needs the sort permutation.
        sorted_data, sorted_indices = torch.sort(flattened)
        data_out, inverse_indices, counts = sorted_indices_unique_flat(
            sorted_data, sorted_indices, return_counts
        )
    else:
        # No inverse requested: drop the permutation and take the fast path.
        sorted_data, _ = torch.sort(flattened)
        data_out, inverse_indices, counts = sorted_quick_unique_flat(
            sorted_data, return_counts
        )
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)
    return data_out, inverse_indices, counts