Coverage for src/flag_gems/ops/isin.py: 59%
130 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops.all import reduce_all
9from flag_gems.ops.any import reduce_any
10from flag_gems.ops.unique import _unique2
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import triton_lang_extension as tle
13from flag_gems.utils.libentry import libentry
15logger = logging.getLogger(__name__)
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Return launch parameters with BLOCK_N clamped to next_power_of_2(N).

    BLOCK_M and num_warps pass through unchanged; BLOCK_N is reduced when
    a single power-of-two tile already covers all N elements.
    """
    covering_tile = triton.next_power_of_2(N)
    if covering_tile > BLOCK_N:
        covering_tile = BLOCK_N
    return BLOCK_M, covering_tile, num_warps
# Brute-force membership tile: each program tests BLOCK_M elements of in0
# against all N elements of in1 in chunks of BLOCK_N (O(BLOCK_M * N) compares).
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1] element ids
    row_mask = rows < M  # guard the ragged last tile
    out_ptr += rows
    # Broadcast row offsets across BLOCK_N columns so each in0 element is
    # replicated once per comparison lane.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator: start all-True and AND inequalities when invert,
    # else start all-False and OR equalities.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # Triton rewrites `and` elementwise
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Out-of-range lanes are forced to `invert` (the reduction identity
        # for the combine below), so they cannot corrupt the result.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Collapse the BLOCK_N lanes: all() for invert ("not equal to every
    # element"), any() otherwise ("equal to some element").
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Each CTA processes `tiles_per_cta` tiles spaced `ctas_num` apart;
    # rows past M are masked off inside the impl, so the final partial
    # iteration is safe.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.Tensor,
    in1: torch.Tensor,
    invert: bool,
):
    """Elementwise membership test by all-pairs comparison.

    Compares every element of ``in0`` against every element of ``in1``
    (O(M * N) work), which is the cheaper strategy for small inputs where
    sorting ``in1`` would not pay off.

    Args:
        in0: elements to test for membership.
        in1: test elements.
        invert: if True, compute "not in" instead of "in".

    Returns:
        Bool tensor with the same shape as ``in0``.
    """
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Heuristic launch configuration keyed on the number of query elements.
    if M <= 1024:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(1, 256, N, 4)
    elif M <= 3072:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(2, 256, N, 4)
    elif M <= 6144:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)
    elif M <= 9216:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 256, N, 8)
    else:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)
    # Cap the grid at 65536 CTAs; any remaining tiles are covered by the
    # grid-stride loop inside the kernel (tiles_per_cta iterations).
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[grid](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
# Binary-search membership tile: each program tests BLOCK_M elements of in0
# against a pre-sorted in1 in O(BLOCK_M * log N) loads.
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M  # guard the ragged last tile

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # "found" flag per lane
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose interval [start, end) is still open
    # A fixed trip count of log_n iterations replaces a data-dependent
    # while-loop: it is enough to shrink any [0, N) interval to empty.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # Standard interval narrowing; a lane that found its value keeps
        # narrowing until the interval closes, but `out` stays sticky.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out, negated per-element when invert is set
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in (must be sorted ascending)
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Each CTA processes `tiles_per_cta` tiles spaced `ctas_num` apart;
    # elements past M are masked off inside the impl.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.Tensor,
    in1: torch.Tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Elementwise membership test via binary search over sorted ``in1``.

    ``in1`` is sorted (or deduplicated, which also sorts) once, then each
    element of ``in0`` is located with a bounded binary search on device —
    O((M + N log N) + M log N) versus the all-pairs O(M * N) strategy.

    Args:
        in0: elements to test for membership.
        in1: test elements.
        invert: if True, compute "not in" instead of "in".
        unique_in0: deduplicate ``in0`` first and scatter results back.
        unique_in1: deduplicate ``in1`` instead of merely sorting it.

    Returns:
        Bool tensor with the same shape as ``in0``.
    """
    # unique or sort or ravel
    if unique_in0:
        # `unique_order` maps every original position of in0 to its slot in
        # the deduplicated array, so per-unique results can be gathered back.
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        # Binary search only needs sorted order; duplicates are harmless.
        in1_ravel, _ = torch.sort(in1.ravel())
    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Heuristic launch configuration; launch_arg's middle slot serves as
    # BLOCK_M here (capped at next_power_of_2(M)).
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    # Enough iterations to close any [0, N) search interval.
    log_n = int(math.log2(N)) + 1
    # Cap the grid at 65536 CTAs; remaining tiles are covered by the
    # grid-stride loop inside the kernel.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[grid](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter per-unique-value answers back to every original position.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of ``in0`` is present in ``in1``.

    Mirrors ``torch.isin``: one of the two arguments may be a Python
    scalar/sequence, which is promoted to a tensor on the other's device.

    Args:
        in0: elements to test.
        in1: test elements.
        assume_unique: caller guarantees both inputs hold unique values,
            allowing the dedup pass on large ``in1`` to be skipped.
        invert: if True, compute "not in" instead of "in".

    Returns:
        Bool tensor with the same shape as ``in0``.
    """
    logger.debug("GEMS ISIN")
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0:
        # No query elements: empty bool result regardless of `invert`.
        return torch.zeros_like(in0, dtype=torch.bool)
    if in1.numel() == 0:
        # Empty test set: nothing is "in", so every element answers `invert`.
        # (Previously this returned all-False even with invert=True, which
        # disagrees with torch.isin.)
        return torch.full_like(in0, invert, dtype=torch.bool)
    if in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    if assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)