Coverage for src/flag_gems/runtime/backend/_cambricon/ops/isin.py: 0%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import math
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.libentry import libentry
10from ..utils import MAX_GRID_SIZE_X
11from .all import reduce_all
12from .any import reduce_any
13from .unique import _unique2
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Assemble a kernel launch triple, capping the inner tile size.

    BLOCK_N is clamped to the next power of two >= N so a single tile never
    spans more lanes than the (padded) problem needs; BLOCK_M and num_warps
    pass through unchanged.
    """
    padded_n = triton.next_power_of_2(N)
    if padded_n < BLOCK_N:
        tile_n = padded_n
    else:
        tile_n = BLOCK_N
    return BLOCK_M, tile_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    """Brute-force membership test for one BLOCK_M-element tile of in0.

    Each of the tile's elements is compared against all N elements of in1,
    BLOCK_N columns at a time; the per-row comparison results are then
    reduced to a single boolean per element and stored to out_ptr.
    """
    # Row indices of this tile, shaped [BLOCK_M, 1]; mask guards the final
    # partial tile when M is not a multiple of BLOCK_M.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    out_ptr += rows
    # Adding zero vectors broadcasts the pointers to 2-D tiles (presumably
    # [BLOCK_M, BLOCK_N] for in0 and [BLOCK_M, 1] for in1's base — verify
    # against Triton broadcasting rules).
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]
    # Accumulator starts at the reduction identity: 1 for the AND-style
    # (invert) accumulation, 0 for the OR-style one.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # invert: AND-accumulate "differs from every in1 element";
        # otherwise: OR-accumulate "matches some in1 element".
        # Out-of-range lanes are reset to the identity value so they do not
        # perturb the row reduction below.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Row-wise reduction: all() under invert, any() otherwise.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    """Driver kernel: each program processes `tiles_per_cta` row tiles.

    Tiles are assigned with a stride equal to the number of launched
    programs, so the work stays balanced when M exceeds one tile per
    program.  Out-of-range tiles are masked off inside the impl.
    """
    pid = tl.program_id(0)
    ctas_num = tl.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Membership test via brute-force pairwise comparison.

    Flattens both inputs, picks a tile configuration from the size of in0,
    and launches the comparison kernel with a grid-stride schedule.
    Returns a bool tensor shaped like in0: True where the element occurs
    in in1, flipped elementwise when `invert` is set.
    """
    flat_in0 = in0.contiguous().ravel()
    flat_in1 = in1.contiguous().ravel()
    num_rows = in0.numel()
    num_cols = in1.numel()
    # Size-keyed launch configs: (upper limit of num_rows, BLOCK_M, BLOCK_N, warps).
    for limit, bm, bn, warps in (
        (1024, 1, 256, 1),
        (3072, 2, 256, 1),
        (6144, 4, 128, 1),
        (9216, 4, 256, 1),
    ):
        if num_rows <= limit:
            BLOCK_M, BLOCK_N, num_warps = launch_arg(bm, bn, num_cols, warps)
            break
    else:
        # Fallback configuration for anything larger.
        BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, num_cols, 1)
    # Cap the grid by the hardware limit, then spread leftover tiles across
    # programs via tiles_per_cta.
    ctas_num = min(MAX_GRID_SIZE_X // num_warps, triton.cdiv(num_rows, BLOCK_M))
    tiles_per_cta = triton.cdiv(num_rows, BLOCK_M * ctas_num)
    out = torch.empty_like(flat_in0, dtype=torch.bool)
    with torch_device_fn.device(flat_in0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat_in0,
            flat_in1,
            out,
            num_rows,
            num_cols,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    """Binary-search one BLOCK_M tile of in0 against the sorted in1.

    Runs a fixed log_n iterations of bisection per element; a match is
    latched into `out` whenever the probed midpoint equals the query.
    """
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M  # guard the final partial tile

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # per-element "found" flag
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    # Fixed trip count replaces a data-dependent while loop; the host passes
    # log_n = int(log2(N)) + 1, enough iterations to close every [start, end)
    # interval.  When mid_val == query the interval does not shrink, but by
    # then `out` is already latched true, so the extra iterations are benign.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    # NOTE(review): `not out` relies on Triton applying elementwise logical
    # negation to an int1 tensor — confirm against the targeted Triton version.
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    """Driver kernel: each program binary-searches `tiles_per_cta` tiles.

    Tiles are assigned with a stride of the launched program count; any
    tile index mapping past M is masked off inside the impl.
    """
    pid = tl.program_id(0)
    ctas_num = tl.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Membership test via binary search over a sorted copy of in1.

    Optionally deduplicates either input with `_unique2`; when in0 is
    deduplicated, the per-unique-value answers are gathered back to the
    original element positions.  Returns a bool tensor shaped like in0.
    """
    # Prepare a flat in0 (optionally deduplicated) and a sorted flat in1.
    if unique_in0:
        flat_in0, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        flat_in0 = in0.contiguous().ravel()
    if unique_in1:
        sorted_in1, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        sorted_in1, _ = torch.sort(in1.ravel())
    num_queries = flat_in0.numel()
    num_keys = sorted_in1.numel()
    # Size-keyed launch configs: (upper limit of num_queries, BLOCK_M, warps).
    for limit, bm, warps in (
        (1 << 20, 512, 8),
        (1 << 22, 1024, 8),
        (1 << 23, 2048, 16),
        (1 << 28, 4096, 32),
    ):
        if num_queries <= limit:
            _, BLOCK_M, num_warps = launch_arg(None, bm, num_queries, warps)
            break
    else:
        # Fallback configuration for anything larger.
        _, BLOCK_M, num_warps = launch_arg(None, 2048, num_queries, 16)
    # Fixed bisection trip count covering the whole [0, num_keys) range.
    log_n = int(math.log2(num_keys)) + 1
    # Cap the grid by the hardware limit, then spread leftover tiles across
    # programs via tiles_per_cta.
    ctas_num = min(MAX_GRID_SIZE_X // num_warps, triton.cdiv(num_queries, BLOCK_M))
    tiles_per_cta = triton.cdiv(num_queries, BLOCK_M * ctas_num)
    out = torch.empty_like(flat_in0, dtype=torch.bool)
    with torch_device_fn.device(flat_in0.device.index):
        isin_by_search_kernel[(ctas_num,)](
            flat_in0,
            sorted_in1,
            out,
            num_queries,
            num_keys,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter the per-unique-value answers back to element positions.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor of in0's shape marking which elements occur in in1.

    Mirrors ``torch.isin``: a scalar argument is promoted to a tensor on the
    other argument's device; ``invert=True`` flips the result elementwise.

    Args:
        in0: Tensor (or scalar) of query elements.
        in1: Tensor (or scalar) of test elements.
        assume_unique: If True, skip deduplicating ``in1`` before searching.
        invert: If True, mark elements NOT present in ``in1``.

    Returns:
        ``torch.bool`` tensor shaped like ``in0``.
    """
    # Promote a scalar operand to a tensor on the other operand's device.
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # With an empty test set nothing is "in" it, so the answer is
        # all-False — or all-True when inverted.  (The previous code returned
        # zeros unconditionally, disagreeing with torch.isin for invert=True.)
        return torch.full_like(in0, fill_value=bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        # Small problem: O(M*N) brute-force comparison kernel.
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        # Sort in1 and binary-search each element of in0 against it.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        # Very large in1: deduplicate it first to shrink the search space.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)