Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/isin.py: 0%
126 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import math
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.all import reduce_all
8from flag_gems.ops.any import reduce_any
9from flag_gems.ops.unique import _unique2
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils.libentry import libentry
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Return launch parameters with BLOCK_N clamped to next_power_of_2(N).

    BLOCK_M and num_warps pass through unchanged; BLOCK_N never exceeds the
    smallest power of two that covers N elements.
    """
    clamped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, clamped_block_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # Brute-force membership test: one tile of BLOCK_M elements of `in0` is
    # compared against all N elements of `in1`, BLOCK_N columns at a time.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1) indices into in0
    row_mask = rows < M
    out_ptr += rows
    # Broadcast the base pointers up to (BLOCK_M, BLOCK_N) tiles.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator per (row, col): starts True when inverting (still distinct
    # from everything seen) and False otherwise (no match seen yet).
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # invert: AND-accumulate "differs from this in1 chunk";
        # normal: OR-accumulate "equals some element of this in1 chunk".
        # Out-of-range lanes are pinned to the reduction identity (`invert`).
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Row-wise reduce: all-distinct when inverting, any-match otherwise.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Driver: each program processes `tiles_per_cta` tiles, advancing by the
    # total number of launched programs each iteration, and hands every tile
    # to isin_by_comparation_impl.
    pid = tl.program_id(0)
    ctas_num = tl.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Elementwise membership of `in0` in `in1` via pairwise comparison.

    Flattens both tensors, launches the comparison kernel over tiles of
    `in0`, and reshapes the boolean result back to `in0`'s shape. When
    `invert` is True the result is flipped (True where NOT a member).
    """
    flat0 = in0.contiguous().ravel()
    flat1 = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Tile-size schedule keyed on the number of in0 elements.
    if M <= 1024:
        tile_m, tile_n = 1, 256
    elif M <= 3072:
        tile_m, tile_n = 2, 256
    elif M <= 6144:
        tile_m, tile_n = 4, 128
    elif M <= 9216:
        tile_m, tile_n = 4, 256
    else:
        tile_m, tile_n = 4, 128
    BLOCK_M, BLOCK_N, num_warps = launch_arg(tile_m, tile_n, N, 1)
    ctas_num = min(16 // num_warps, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat0,
            flat1,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Membership test for one tile of `in0`: each lane binary-searches the
    # sorted `in1` array. `log_n` is an upper bound on the iterations needed
    # for the search interval of every lane to close.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # per-lane "found a match" flag
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose interval [start, end) is still open
    for i in range(log_n):
        # Closed lanes probe index 0 with a masked (inactive) load.
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # Standard interval narrowing; an equal probe leaves both bounds,
        # so the lane's interval closes once start catches up to end.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out (flip the found-flag when `invert` is requested)
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Driver: each program processes `tiles_per_cta` tiles, advancing by the
    # total number of launched programs each iteration, and hands every tile
    # to isin_by_search_impl.
    pid = tl.program_id(0)
    ctas_num = tl.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Elementwise membership of `in0` in `in1` via binary search.

    `in1` is sorted (or uniqued, which also sorts) up front; each element of
    `in0` is then binary-searched in the sorted array. If `unique_in0` is set,
    the search runs on deduplicated values and the per-unique-value results
    are scattered back through the inverse-index mapping.
    """
    # Prepare inputs: deduplicate and/or sort as requested.
    unique_order = None
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Tile-size schedule keyed on M; launch_arg clamps it to
    # next_power_of_2(M).
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        tile = 512
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        tile = 1024
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        tile = 2048
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        tile = 4096
    else:
        tile = 2048
    _, BLOCK_M, num_warps = launch_arg(None, tile, M, 1)
    # Upper bound on binary-search iterations over N sorted elements.
    log_n = int(math.log2(N)) + 1
    ctas_num = min(16 // num_warps, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Map per-unique-value answers back to every original position.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of `in0` is present in `in1`.

    Mirrors torch.isin: returns a bool tensor shaped like `in0`, True where
    the element occurs in `in1` (flipped when `invert` is True). At least one
    of `in0`/`in1` must be a tensor; a non-tensor operand is promoted onto
    the other's device. Dispatches by size: small inputs use brute-force
    comparison; larger ones use binary search, deduplicating `in1` first
    when it is very large and not assumed unique.
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Fast path for empty operands. With an empty `in1` nothing matches,
        # so the result is all-False normally and all-True when inverted
        # (matching torch.isin); an empty `in0` yields an empty result either
        # way. Fix: the old code returned zeros even when `invert` was True.
        return torch.full_like(in0, fill_value=invert, dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)