Coverage for src/flag_gems/runtime/backend/_cambricon/ops/unique.py: 0%
91 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.runtime import torch_device_fn
6from flag_gems.utils.libentry import libentry
8from ..utils import TOTAL_CORE_NUM
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(11, 17, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_ne_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_data_2: tl.tensor,  # sorted data shifted right by one (see caller)
    ne_out_ptr: tl.tensor,
    tile_size: tl.constexpr,  # total number of elements to process
    BLOCK_SIZE: tl.constexpr,
):
    """Mark boundaries between runs of equal values in sorted data.

    Writes ne_out[i] = (i > 0) and (sorted_data[i] != sorted_data[i-1]),
    i.e. True exactly where a new unique value starts (position 0 is
    always written as False, so a later cumsum yields 0-based unique ids).
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    # Each program handles a contiguous chunk of ceil(tile_size / num_jobs)
    # elements, iterated in BLOCK_SIZE-wide tiles.
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n
    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        a = tl.load(sorted_data_ptr + offset, mask=mask)
        b = tl.load(sorted_data_2 + offset, mask=mask)
        # ne: the (offset > 0) factor forces position 0 to 0 regardless of
        # the (uninitialized) first element of sorted_data_2.
        ne_result = (offset > 0) * (a != b)
        tl.store(ne_out_ptr + offset, ne_result, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": k}, num_stages=s, num_warps=1)
        for k in [32, 256, 1024, 2048, 4096]
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_unique_out_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in: original positions of sorted values
    ne_result_ptr: tl.tensor,  # in: boundary flags from get_ne_kernel
    pre_sum_ptr: tl.tensor,  # in: inclusive cumsum of ne_result (0-based unique id)
    idx_ptr: tl.tensor,  # out: first sorted position of each unique value
    data_out_ptr: tl.tensor,  # out: compacted unique values
    inverse_indices_ptr: tl.tensor,  # out: unique id for each original element
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter sorted data into the compacted unique-value output.

    For each sorted position i with unique id u = pre_sum[i]:
      * data_out[u] = sorted_data[i]  (duplicates all write the same value,
        so last-writer-wins races are benign),
      * inverse_indices[sorted_indices[i]] = u  (if return_inverse),
      * idx[u] = i at the first occurrence of each unique value
        (if return_counts), later differenced to get counts.
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    # Contiguous per-program chunk, processed in BLOCK_SIZE tiles.
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n
    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        sorted_data = tl.load(sorted_data_ptr + offset, mask=mask)
        pre_sum_data = tl.load(pre_sum_ptr + offset, mask=mask)

        # data_out: scatter_(to=pre_sum_data, sorted_data)
        tl.store(data_out_ptr + pre_sum_data, sorted_data, mask=mask)

        # inverse_indices: scatter_(to=sorted_indices, pre_sum_data)
        if return_inverse:
            sorted_indices = tl.load(sorted_indices_ptr + offset, mask=mask)
            tl.store(inverse_indices_ptr + sorted_indices, pre_sum_data, mask=mask)

        # idx: mark positions of unique values in idx_ptr
        # (only the first element of each equal-run stores, so no races)
        if return_counts:
            ne_result = tl.load(ne_result_ptr + offset, mask=mask)
            idx_mask = ((offset == 0) | ne_result.to(tl.int1)) & mask
            tl.store(idx_ptr + pre_sum_data, offset, mask=idx_mask)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(7, 14, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_output_counts_kernel(
    idx_ptr: tl.tensor,  # in: first sorted position of each unique value
    idx_next_ptr: tl.tensor,  # in: idx shifted left by one, last = total size
    counts_ptr: tl.tensor,  # out
    tile_size: tl.constexpr,  # number of unique values
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise counts[i] = idx_next[i] - idx[i].

    Since idx holds the first sorted position of each unique value and
    idx_next its successor's first position (with the total element count
    appended), the difference is the multiplicity of each unique value.
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    # Contiguous per-program chunk, processed in BLOCK_SIZE tiles.
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n

    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        # load idx
        idx = tl.load(idx_ptr + offset, mask=mask)
        # load idx_next
        idx_next = tl.load(idx_next_ptr + offset, mask=mask)
        # diff
        counts = idx_next - idx
        # store counts
        tl.store(counts_ptr + offset, counts, mask=mask)
def sorted_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Compute unique values of an already-sorted 1-D tensor.

    Args:
        sorted_data: flat tensor, sorted ascending.
        sorted_indices: original positions of the sorted elements
            (as returned by torch.sort).
        return_inverse: also compute, for each original element, the index
            of its value in the unique output.
        return_counts: also compute the multiplicity of each unique value.

    Returns:
        (data_out, inverse_indices, counts) where inverse_indices / counts
        are None when not requested. inverse_indices is flat (caller
        reshapes to the original input shape).
    """
    num_tasks = sorted_data.numel()

    # Empty input: the general path below would crash on pre_sum[-1]
    # (indexing a 0-element tensor) and launch a zero-sized grid, so
    # return well-formed empty results directly.
    if num_tasks == 0:
        inverse_indices = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_inverse
            else None
        )
        counts = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_counts
            else None
        )
        return sorted_data.clone(), inverse_indices, counts

    grid = lambda meta: (
        min(triton.cdiv(num_tasks, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
    )

    # allocate tensors
    ne_out = torch.empty_like(sorted_data, dtype=torch.bool)
    data_out = torch.empty_like(sorted_data)
    inverse_indices = (
        torch.empty_like(sorted_data, dtype=torch.int64) if return_inverse else None
    )
    # idx[u] will hold the first sorted position of unique value u.
    idx = torch.empty_like(sorted_data, dtype=torch.int64) if return_counts else None

    # sorted_data shifted right by one; element 0 is left uninitialized,
    # which is safe because get_ne_kernel forces position 0 to False.
    sorted_data_2 = torch.empty_like(sorted_data)
    sorted_data_2[1:] = sorted_data[:-1]

    # launch kernels
    with torch_device_fn.device(sorted_data.device.index):
        # ne_out[i] = (i > 0) and (sorted_data[i] != sorted_data[i-1])
        get_ne_kernel[grid](
            sorted_data,
            sorted_data_2,
            ne_out,
            tile_size=num_tasks,
        )
        # pre_sum[i] = 0-based unique id of sorted element i
        pre_sum = ne_out.cumsum(axis=0)
        get_unique_out_kernel[grid](
            sorted_data,
            sorted_indices,
            ne_out,
            pre_sum,
            idx,
            data_out,
            inverse_indices,
            return_inverse,
            return_counts,
            tile_size=num_tasks,
        )

    # Highest unique id + 1 == number of unique values.
    out_size = pre_sum[-1].item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        sorted_data_size = len(sorted_data)
        # idx_next[u] = first position of the next unique value
        # (total size for the last one), so counts = idx_next - idx.
        idx_next = torch.roll(idx, -1)
        idx_next[-1] = sorted_data_size
        counts = torch.zeros_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            get_output_counts_kernel[grid](
                idx,
                idx_next,
                counts,  # out
                tile_size=out_size,
            )
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Entry point mirroring ATen's _unique2: unique over the flattened input.

    The `sorted` flag is accepted for signature compatibility; the output is
    always produced in sorted order. inverse_indices (when requested) is
    reshaped to match the input's shape.
    """
    flat = in0.ravel()
    sorted_data, sorted_indices = torch.sort(flat, stable=False)
    data_out, inverse_indices, counts = sorted_unique_flat(
        sorted_data, sorted_indices, return_inverse, return_counts
    )
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)
    return data_out, inverse_indices, counts