Coverage for src/flag_gems/runtime/backend/_mthreads/ops/isin.py: 0%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import math
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.all import reduce_all
8from flag_gems.ops.any import reduce_any
9from flag_gems.ops.unique import _unique2
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.libentry import libentry
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Build a kernel launch triple, capping the column tile at the problem size.

    Passes BLOCK_M and num_warps through untouched; BLOCK_N is clamped to
    triton.next_power_of_2(N) so a tile never spans more columns than the
    (power-of-two padded) problem requires.
    """
    padded_n = triton.next_power_of_2(N)
    tile_n = BLOCK_N if BLOCK_N < padded_n else padded_n
    return BLOCK_M, tile_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # Process one BLOCK_M-element tile of in0: compare each element against
    # every element of in1 (scanned in BLOCK_N-wide chunks) and store a
    # boolean per in0 element — "any match" normally, "no match" when invert.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1] indices into in0
    row_mask = rows < M
    out_ptr += rows
    # Adding a zero [BLOCK_N] vector broadcasts the in0 pointers to shape
    # [BLOCK_M, BLOCK_N], replicating each in0 element across the compare width.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator starts at the identity of the final reduction:
    # 1 (True) for the AND-style invert case, 0 (False) for the OR-style case.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # invert: lane stays True only while no in1 element equals in0;
        # otherwise: lane becomes True once any in1 element matches.
        # Out-of-range lanes are pinned to the reduction identity (invert).
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Collapse the BLOCK_N lanes per row: all() when invert, any() otherwise.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Kernel entry point: each program walks `tiles_per_cta` tiles, stepping
    # by the total grid size, and delegates per-tile work to
    # isin_by_comparation_impl.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Membership test of ``in0`` in ``in1`` by brute-force comparison.

    Flattens both tensors and launches a tiled kernel comparing every
    element of ``in0`` against every element of ``in1``. Returns a bool
    tensor shaped like ``in0``; ``invert=True`` negates the result.
    Used by the dispatcher only when both inputs are small.
    """
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Size-tuned launch heuristics: (BLOCK_M, BLOCK_N, num_warps).
    if M <= 1024:
        raw_cfg = (1, 256, 4)
    elif M <= 3072:
        raw_cfg = (2, 256, 4)
    elif M <= 6144:
        raw_cfg = (4, 128, 4)
    elif M <= 9216:
        raw_cfg = (4, 256, 8)
    else:
        raw_cfg = (4, 128, 4)
    BLOCK_M, BLOCK_N, num_warps = launch_arg(raw_cfg[0], raw_cfg[1], N, raw_cfg[2])
    # Cap the grid; each CTA then grid-stride-loops over its share of tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,
            out,
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # For one BLOCK_M tile of in0, binary-search each element in the sorted
    # in1 array. log_n is a compile-time upper bound on bisection steps.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # per-lane "found" flag
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose search interval is still open
    # Fixed trip count emulates a data-dependent while loop; while_mask
    # freezes lanes whose interval has already closed.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Kernel entry point: each program walks `tiles_per_cta` tiles, stepping
    # by the total grid size, and delegates per-tile work to
    # isin_by_search_impl.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Membership test of ``in0`` in ``in1`` via on-device binary search.

    ``in1`` is deduplicated (``unique_in1=True``) or merely sorted, then
    each element of ``in0`` is looked up by a fixed-depth binary search.
    With ``unique_in0=True`` the search runs over the unique values of
    ``in0`` and results are gathered back through the inverse index.
    Returns a bool tensor shaped like ``in0``; ``invert`` negates it.
    """
    # unique or sort or ravel
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Size-tuned BLOCK_M choice; num_warps is 8 across the board.
    size_to_block = (
        (1 << 20, 512),   # up to 1048576
        (1 << 22, 1024),  # up to 4194304
        (1 << 23, 2048),  # up to 8388608
        (1 << 28, 4096),  # up to 268435456
    )
    chosen = 2048  # fallback for M beyond 2**28
    for limit, candidate in size_to_block:
        if M <= limit:
            chosen = candidate
            break
    _, BLOCK_M, num_warps = launch_arg(None, chosen, M, 8)
    # Enough bisection steps to exhaust any interval of length N.
    log_n = int(math.log2(N)) + 1
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,
            out,
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter per-unique-value answers back onto the original layout.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of ``in0`` appears in ``in1``.

    Mirrors ``torch.isin``: either argument may be a Python scalar as long
    as the other is a tensor. Returns a bool tensor with the shape of
    ``in0`` — True where the element is found in ``in1`` — and ``invert``
    flips the result. ``assume_unique`` promises the inputs hold no
    duplicates, which skips the deduplication pass on large ``in1``.

    Dispatch:
      * either input empty: answered directly (no kernel launch);
      * both inputs <= 12288 elements: brute-force comparison kernel;
      * otherwise: binary search over sorted (and, for very large ``in1``,
        deduplicated) test elements.
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Nothing can be found in an empty `in1`, so the plain answer is
        # all-False — but `invert=True` must then yield all-True (matching
        # torch.isin). The previous code always returned zeros, which was
        # wrong for invert=True with an empty `in1`.
        return torch.full_like(in0, fill_value=bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)