Coverage for src/flag_gems/runtime/backend/_ascend/ops/unique.py: 0%
314 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
9from flag_gems.utils.libentry import libentry
# Module-level logger named after this module's basename, nested under the
# flag_gems Ascend runtime logger hierarchy.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def simple_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    # Single-CTA unique over an already-sorted flat input. tile_size is the
    # next power of two >= num_tasks, so a single tile covers everything.
    # Writes compacted unique values to data_out_ptr, optional inverse
    # indices, optional first-occurrence indices (idx_ptr), and the value
    # (#unique - 1) to unique_size_ptr (the host adds 1 back).
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # load each element and its predecessor (lane 0 compares with itself)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne & cumsum: flag lanes whose value differs from their predecessor
    ne_result = tl.where(i0 > 0, a != b, 0)
    # Zero the flags on out-of-range lanes: a masked tl.load without `other`
    # yields undefined values there, which would otherwise leak into the
    # cumsum (and hence into unique_size) whenever num_tasks < tile_size.
    # Mirrors the masking done in local_quick_unique_flat_impl.
    ne_result = tl.where(mask, ne_result, 0)
    cumsum = tl.cumsum(ne_result)

    # unique_size: the cumsum at the last lane equals (#unique - 1)
    unique_size_mask = i0 == tile_size - 1
    tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)

    # data_out: scatter_(to=cumsum, sorted_data)
    tl.store(data_out_ptr + cumsum, a, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum)
    if return_inverse:
        sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)
        tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first-occurrence position of each unique value (used for counts)
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)
@triton.jit
def output_counts_flat_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    # One tile of the counts computation: counts[i] = idx[i + 1] - idx[i],
    # i.e. the distance between consecutive first-occurrence indices.
    lane = tl.arange(0, tile_size)
    pos = global_pid * tile_size + lane
    in_range = pos < num_tasks

    first_occurrence = tl.load(idx_ptr + pos, mask=in_range)

    succ = pos + 1
    succ_in_range = succ < num_tasks
    next_occurrence = tl.load(idx_ptr + succ, mask=succ_in_range)

    # The last unique value has no successor; its run extends to the end of
    # the original (pre-unique) data.
    run_length = tl.where(
        succ_in_range,
        next_occurrence - first_occurrence,
        origin_num_tasks - first_occurrence,
    )
    tl.store(counts_ptr + pos, run_length, mask=in_range)
@libentry()
@triton.jit
def output_counts_flat_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    # Grid-stride driver: each CTA walks the tile grid with a stride equal to
    # the number of launched programs.
    cta_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(0, tiles_per_cta):
        output_counts_flat_impl(
            cta_id + step * stride,
            idx_ptr,
            origin_num_tasks,  # in
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def quick_output_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    # Fused output pass over one tile: derive per-unique counts from the
    # first-occurrence index array and gather the unique values themselves.
    lane = tl.arange(0, tile_size)
    pos = global_pid * tile_size + lane
    in_range = pos < num_tasks

    first_occurrence = tl.load(idx_ptr + pos, mask=in_range)

    succ = pos + 1
    succ_in_range = succ < num_tasks
    next_occurrence = tl.load(idx_ptr + succ, mask=succ_in_range)

    # Run length = distance to the next first-occurrence; the final run is
    # closed off by the original element count.
    run_length = tl.where(
        succ_in_range,
        next_occurrence - first_occurrence,
        origin_num_tasks - first_occurrence,
    )
    tl.store(counts_ptr + pos, run_length, mask=in_range)

    # data_out[pos] = sorted_data[idx[pos]]: gather unique values
    unique_vals = tl.load(sorted_data_ptr + first_occurrence, mask=in_range)
    tl.store(data_out_ptr + pos, unique_vals, mask=in_range)
@libentry()
@triton.jit
def quick_output_flat_kernel(
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    # Grid-stride driver for quick_output_flat_impl: each CTA covers
    # tiles_per_cta tiles, spaced num_programs apart.
    cta_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(0, tiles_per_cta):
        quick_output_flat_impl(
            cta_id + step * stride,
            sorted_data_ptr,
            idx_ptr,
            origin_num_tasks,  # in
            data_out_ptr,
            counts_ptr,  # out
            num_tasks,
            tile_size,
        )
@triton.jit
def local_quick_unique_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    # Per-tile dedup of the sorted input: flags elements that differ from
    # their predecessor, compacts them inside the tile (into local_unique or,
    # when counts are requested, their first-occurrence indices into
    # origin_idx), and records this tile's unique count in tile_sum.
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load
    a = tl.load(sorted_data_ptr + i0, mask=mask, other=0)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask, other=0)

    # ne & cumsum
    # At i0 == 0 (the very first element overall) ne_result is 1: it is the
    # first unique value. Everywhere else ne_result = (a != b).
    ne_result = tl.where(i0 > 0, a != b, 1)
    ne_result = tl.where(mask, ne_result, 0)  # keep only valid lanes

    cumsum = tl.cumsum(ne_result)

    # The first unique value (i0 == 0) has cumsum == 1, so its slot is
    # cumsum - 1; subsequent unique values advance incrementally.
    local_unique_offset = cumsum - 1  # cumsum is 1-based -> zero-based slot
    local_unique_mask = mask

    if return_counts:
        # origin_idx: store only at unique-value lanes
        origin_idx_mask = ne_result.to(tl.int1) & local_unique_mask
        tl.store(
            origin_idx_ptr + (offset + local_unique_offset),
            i0,
            mask=origin_idx_mask,
        )
    else:
        # local_unique: store only at unique-value lanes
        store_mask = ne_result.to(tl.int1) & local_unique_mask
        tl.store(local_unique_ptr + (offset + local_unique_offset), a, mask=store_mask)

    # tile_sum - the cumsum value at the last valid lane equals this tile's
    # unique count (cumsum is zero beyond the mask, so tl.max finds it).
    valid_cumsum = tl.where(mask, cumsum, 0)
    last_cumsum = tl.max(valid_cumsum)

    # Use last_cumsum directly; the first tile needs no special casing.
    if global_pid < global_ctas_num:
        tl.store(tile_sum_ptr + global_pid, last_cumsum)
@libentry()
@triton.jit
def local_quick_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    # Grid-stride driver for the per-tile dedup pass.
    cta_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(0, tiles_per_cta):
        local_quick_unique_flat_impl(
            cta_id + step * stride,
            sorted_data_ptr,  # in
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
@triton.jit
def global_quick_unique_flat_impl(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,  # size of each chunk
    MAX_CHUNKS: tl.constexpr,  # maximum number of chunks
):
    # Compaction pass for one tile: computes the number of unique values in
    # all preceding tiles (prefix of tile_sum), then copies this tile's
    # compacted values (or first-occurrence indices) to their final slots.
    # Returns the updated running prefix `total` for grid-stride callers.
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # load tile_sum in chunks to keep on-chip temporaries bounded
    # ("avoid UB overflow" per the original author's note)
    start_idx = tl.maximum(global_pid - ctas_num, 0)
    end_idx = tl.minimum(global_pid, global_ctas_num)

    # chunked accumulation of pre_tile_sum
    # NOTE(review): the window [start_idx, end_idx) is at most ctas_num wide;
    # this assumes ctas_num <= CHUNK_SIZE * MAX_CHUNKS — confirm against the
    # launch configuration in sorted_quick_unique_flat.
    total_sum = 0
    total_sum = total_sum.to(tl.int64)
    for chunk_id in range(MAX_CHUNKS):
        chunk_start = start_idx + chunk_id * CHUNK_SIZE

        # only process this chunk if it lies in the valid range
        if chunk_start < end_idx:
            p = tl.arange(0, CHUNK_SIZE)
            p_idx = chunk_start + p

            # mask: indices must fall inside [start_idx, end_idx)
            pre_tile_sum_mask = (
                (p_idx < end_idx) & (p_idx >= start_idx) & (p_idx < global_ctas_num)
            )

            pre_tile_sum = tl.load(
                tile_sum_ptr + p_idx, mask=pre_tile_sum_mask, other=0
            )
            total_sum += tl.sum(pre_tile_sum)

    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask, other=0)

    # total: running count of unique values before this tile
    total += total_sum

    # the last tile publishes the grand total into tile_sum's final slot;
    # the host reads it back as out_size
    if global_pid == global_ctas_num - 1:
        tl.store(tile_sum_ptr + global_pid, total + cur_tile_sum)

    # idx or data_out: only the first cur_tile_sum lanes of this tile hold
    # compacted entries; they land at offset `total` in the output
    tile_mask = r < cur_tile_sum
    out_offset = total + r

    if return_counts:
        # move origin_idx to idx_ptr
        origin_idx = tl.load(origin_idx_ptr + i0, mask=mask, other=0)
        tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask)
    else:
        # move local_unique to data_out_ptr
        local_unique = tl.load(local_unique_ptr + i0, mask=mask, other=0)
        tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask)

    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    # Driver for the global compaction pass: either one tile per CTA or a
    # grid-stride loop that threads the running prefix through each tile.
    cta_id = tle.program_id(0)
    ctas_num = tle.num_programs(0)

    # chunking parameters for the tile_sum prefix accumulation
    CHUNK_SIZE: tl.constexpr = 2048  # elements handled per chunk
    MAX_CHUNKS: tl.constexpr = 32  # chunk cap (2048 * 32 = 65536)

    if one_tile_per_cta:
        # monolithic style: a single tile per CTA, prefix starts at 0
        global_quick_unique_flat_impl(
            cta_id,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
            CHUNK_SIZE,
            MAX_CHUNKS,
        )
    else:
        # grid-stride style: carry the running prefix across tiles
        running_total = tl.zeros([1], dtype=tl.int64)
        for step in range(0, tiles_per_cta):
            running_total = global_quick_unique_flat_impl(
                cta_id + step * ctas_num,
                running_total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
                CHUNK_SIZE,
                MAX_CHUNKS,
            )
def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
    """Unique over a flattened, pre-sorted tensor, without inverse indices.

    Three-pass scheme:
      1. local_quick_unique_flat_kernel: per-tile dedup, writing per-tile
         unique counts into ``tile_sum`` and either local unique values
         (``local_unique``) or first-occurrence indices (``origin_idx``).
      2. global_quick_unique_flat_kernel: prefix-sums ``tile_sum`` and
         compacts the per-tile results into ``data_out`` / ``idx``; the last
         slot of ``tile_sum`` ends up holding the total unique count.
      3. (counts only) quick_output_flat_kernel: turns consecutive
         first-occurrence indices into counts and gathers the unique values.

    Returns (data_out, None, counts): inverse indices are never produced
    here — callers needing them use sorted_indices_unique_flat instead.
    """
    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(4096, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)

    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the launched CTAs; above the cap each CTA loops over several tiles.
    ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # allocate tensor — only the buffers needed for the requested outputs
    if return_counts:
        local_unique = None
        origin_idx = torch.empty_like(sorted_data, dtype=torch.int64)
        idx = torch.empty_like(origin_idx)
    else:
        local_unique = torch.empty_like(sorted_data)
        origin_idx = None
        idx = None
    counts = None
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = None
    if not return_counts:
        data_out = torch.empty_like(sorted_data)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        local_quick_unique_flat_kernel[grid](
            sorted_data,  # in
            local_unique,
            origin_idx,
            tile_sum,  # out
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        global_quick_unique_flat_kernel[grid](
            local_unique,
            origin_idx,
            tile_sum,  # in
            data_out,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # Host sync point: the global kernel stored the total unique count
        # into the last slot of tile_sum.
        out_size = tile_sum[-1].item()
        if return_counts:
            data_out = torch.empty(
                (out_size,), dtype=sorted_data.dtype, device=sorted_data.device
            )
            idx = idx[:out_size]
            # Reuse origin_idx's storage for counts; its contents were
            # already consumed by the global kernel.
            counts = origin_idx[:out_size]
            quick_output_flat_kernel[grid](
                sorted_data,
                idx,
                num_tasks,  # in
                data_out,
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    if return_counts:
        return data_out, None, counts
    else:
        return data_out[:out_size], None, None
@triton.jit
def local_ne_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,  # sub-block width for chunked processing
):
    # Per-tile pass: write the "differs from predecessor" flag for every
    # element and accumulate the tile's flag sum. The tile is processed in
    # sub-blocks of BLOCK_SIZE_SUB lanes to bound on-chip memory use.

    # start of the current tile
    tile_start = global_pid * tile_size

    # number of sub-blocks. Use tl.cdiv: triton.cdiv is a host-side Python
    # helper and is not callable inside @triton.jit code (consistent with
    # global_cumsum_flat_impl, which already uses tl.cdiv).
    num_sub_blocks = tl.cdiv(tile_size, BLOCK_SIZE_SUB)

    # running per-tile sum of flags
    tile_sum_acc = tl.zeros([], dtype=tl.int32)

    # iterate over sub-blocks
    for sub_block_idx in range(num_sub_blocks):
        # start of this sub-block
        sub_block_start = tile_start + sub_block_idx * BLOCK_SIZE_SUB

        # lane indices of this sub-block
        r = tl.arange(0, BLOCK_SIZE_SUB)
        i0 = sub_block_start + r

        # bounds mask
        mask = (i0 < num_tasks) & (i0 >= 0)
        i0_prev = tl.where(i0 > 0, i0 - 1, 0)

        # load each element and its predecessor
        a = tl.load(sorted_data_ptr + i0, mask=mask, other=0)
        b = tl.load(sorted_data_ptr + i0_prev, mask=mask, other=0)

        # flag: element differs from its predecessor; the global first
        # element (i0 == 0) is flagged 0 so the downstream cumsum is
        # zero-based (the host adds 1 back to out_size)
        ne_result = tl.where(i0 > 0, a != b, 0)
        ne_result = tl.where(mask, ne_result, 0)

        # store the flags
        tl.store(ne_result_ptr + i0, ne_result, mask=mask)

        # accumulate into the tile sum
        sub_block_sum = tl.sum(ne_result)
        tile_sum_acc += sub_block_sum

    # publish the tile's flag sum
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum_acc, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    # Grid-stride driver for the "differs from predecessor" flag pass.
    cta_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(0, tiles_per_cta):
        local_ne_flat_impl(
            cta_id + step * stride,
            sorted_data_ptr,  # in
            ne_result_ptr,
            tile_sum_ptr,  # out
            global_ctas_num,
            num_tasks,
            tile_size,
            BLOCK_SIZE_SUB=256,
        )
@triton.jit
def global_cumsum_flat_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
    MAX_CTAS_NUM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr = 512,
):
    # One tile of the global cumsum pass: adds the prefix of per-tile flag
    # sums to this tile's local cumsum, then scatters unique values into
    # data_out, positions into inverse_indices, and (optionally) the
    # first-occurrence indices into idx. Returns the updated running prefix
    # `total` for grid-stride callers.
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # range of tile_sum entries that this tile still needs to accumulate
    start_idx = tl.maximum(global_pid - ctas_num, 0)
    end_idx = tl.minimum(global_pid, global_ctas_num)
    actual_load_size = end_idx - start_idx
    actual_load_size = actual_load_size.to(tl.int64)

    # accumulate tile_sum in chunks to avoid allocating one oversized tensor
    chunk_sum = 0
    chunk_sum = chunk_sum.to(tl.int64)

    for chunk_id in range(tl.cdiv(MAX_CTAS_NUM, CHUNK_SIZE)):
        # bounds of the current chunk
        chunk_start = chunk_id * CHUNK_SIZE
        chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, actual_load_size)

        # only load inside the valid chunk range
        if chunk_start < actual_load_size:
            p = tl.arange(0, CHUNK_SIZE)
            p_idx = start_idx + chunk_start + p

            # tighter mask: valid within the chunk AND the global window
            pre_tile_sum_mask = (
                (p < (chunk_end - chunk_start))
                & (p_idx >= start_idx)  # valid inside the current chunk
                & (p_idx < end_idx)
                & (p_idx >= 0)
                & (p_idx < global_ctas_num)
            )

            pre_tile_sum = tl.load(
                tile_sum_ptr + p_idx, mask=pre_tile_sum_mask, other=0
            )
            chunk_sum += tl.sum(pre_tile_sum)

    # cumsum: local flag cumsum, later shifted by the accumulated prefix
    total += chunk_sum
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    cumsum = tl.cumsum(ne_result)

    # tile_sum: the last tile publishes the final (zero-based) cumsum value
    # into tile_sum's last slot; the host reads it back as out_size - 1
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(r),
            tile_sum,
            mask=last_tile_sum_mask,
        )
    cumsum += total

    # data_out: scatter_(to=cumsum, sorted_data)
    tl.store(data_out_ptr + cumsum, sorted_data, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum)
    tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first-occurrence position of each unique value
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result_i1) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)

    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    # Driver for the global cumsum pass: one tile per CTA, or a grid-stride
    # loop that threads the running prefix through successive tiles.
    cta_id = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    MAX_CTAS_NUM: tl.constexpr = 65536

    if one_tile_per_cta:  # monolithic kernel style
        global_cumsum_flat_impl(
            cta_id,
            0,
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
            MAX_CTAS_NUM,
        )
    else:  # grid-stride loop style
        running_total = tl.zeros([1], dtype=tl.int64)
        for step in range(0, tiles_per_cta):
            running_total = global_cumsum_flat_impl(
                cta_id + step * ctas_num,
                running_total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
                MAX_CTAS_NUM,
            )
def sorted_indices_unique_flat(
    sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool
):
    """Unique over flattened, pre-sorted data, also producing inverse indices.

    Two-pass scheme:
      1. local_ne_flat_kernel: per-element "differs from predecessor" flags
         (``ne_result``) plus per-tile flag sums (``tile_sum``).
      2. global_cumsum_flat_kernel: global cumulative sum of the flags,
         scattering unique values into ``data_out``, positions into
         ``inverse_indices`` and (optionally) first-occurrence indices into
         ``idx``; the last slot of ``tile_sum`` receives the final cumsum.
    An optional third pass (output_counts_flat_kernel) turns consecutive
    first-occurrence indices into per-value counts.

    Returns (data_out[:out_size], inverse_indices, counts).
    """
    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    # Use larger tiles for very large inputs (>= 160 * 2**20 elements).
    if num_tasks >= 167772160:
        tile_size = 4096
    else:
        tile_size = min(2048, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the launched CTAs; above the cap each CTA loops over several tiles.
    ctas_num = global_ctas_num if global_ctas_num < 65536 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    grid = (ctas_num, 1, 1)
    # allocate tensor
    ne_result = torch.empty_like(sorted_data, dtype=torch.bool)
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = torch.empty_like(sorted_data)
    inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    idx = None
    if return_counts:
        idx = torch.empty_like(inverse_indices)
    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        local_ne_flat_kernel[grid](
            sorted_data,  # in
            ne_result,
            tile_sum,  # out
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
        )
        global_cumsum_flat_kernel[grid](
            ne_result,
            tile_sum,  # in
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_counts=return_counts,
        )
        # The cumsum of "new value" flags is zero-based, hence the +1.
        out_size = tile_sum[-1].item() + 1
        counts = None
        if return_counts:
            idx = idx[:out_size]
            counts = torch.empty_like(idx)
            output_counts_flat_kernel[grid](
                idx,
                num_tasks,  # in
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
            )
    return data_out[:out_size], inverse_indices, counts
def simple_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Single-CTA unique for small flattened, pre-sorted inputs.

    Launches one monolithic kernel whose tile covers the whole input, then,
    if counts were requested, a second pass that turns first-occurrence
    indices into per-value counts.

    Returns (data_out[:out_size], inverse_indices, counts).
    """
    num_tasks = sorted_data.numel()
    grid = (1, 1, 1)

    # Output buffers; optional ones stay None when not requested.
    data_out = torch.empty_like(sorted_data)
    inverse_indices = (
        torch.empty_like(sorted_data, dtype=torch.int64) if return_inverse else None
    )
    idx = (
        torch.empty_like(sorted_data, dtype=torch.int64) if return_counts else None
    )
    unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        simple_unique_flat_kernel[grid](
            sorted_data,
            sorted_indices,  # in
            data_out,
            inverse_indices,
            idx,
            unique_size,  # out
            return_inverse,
            return_counts,
            num_tasks,
            tile_size=triton.next_power_of_2(num_tasks),
            num_warps=8,
        )
    # The kernel stores (#unique - 1); add 1 to recover the count.
    out_size = unique_size.item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            output_counts_flat_kernel[grid](
                idx,
                num_tasks,  # in
                counts,  # out
                num_tasks=out_size,
                tiles_per_cta=1,
                tile_size=triton.next_power_of_2(out_size),
                num_warps=8,
            )
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Flattened unique dispatcher.

    Picks one of three strategies: a single-CTA kernel for small inputs, an
    inverse-index-producing path when inverse indices are requested, and a
    quick multi-pass path otherwise. The inverse indices, when produced, are
    reshaped back to the input's shape.
    """
    logger.debug("GEMS_ASCEND _UNIQUE2")
    flattened = in0.ravel()
    if in0.numel() <= 8192:
        # Small input: one tile covers everything.
        sorted_data, sorted_indices = torch.sort(flattened)
        data_out, inverse_indices, counts = simple_unique_flat(
            sorted_data, sorted_indices, return_inverse, return_counts
        )
    elif return_inverse:
        # Inverse indices required: use the sort-index-aware path.
        sorted_data, sorted_indices = torch.sort(flattened)
        data_out, inverse_indices, counts = sorted_indices_unique_flat(
            sorted_data, sorted_indices, return_counts
        )
    else:
        # No inverse needed: the quick path ignores the sort indices.
        sorted_data, _ = torch.sort(flattened)
        data_out, inverse_indices, counts = sorted_quick_unique_flat(
            sorted_data, return_counts
        )
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)
    return data_out, inverse_indices, counts