Coverage for src/flag_gems/runtime/backend/_hygon/ops/isin.py: 0%
132 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.libentry import libentry
12from .all import reduce_all
13from .any import reduce_any
14from .unique import _unique2
16logger = logging.getLogger(__name__)
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Pass launch parameters through, capping BLOCK_N at next_power_of_2(N).

    Returns the (BLOCK_M, BLOCK_N, num_warps) triple unchanged except that
    BLOCK_N never exceeds the smallest power of two covering N, so a tiny
    second operand does not force an oversized tile.
    """
    capped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, capped_block_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,  # logical tile index along the M axis (grid-stride id)
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks: elements of in0
    N: int,  # num_tasks_1: elements of in1
    BLOCK_M: tl.constexpr,  # tile_size: rows (in0 elements) per tile
    BLOCK_N: tl.constexpr,  # tile_size_1: in1 elements compared per inner step
    invert: tl.constexpr,
):
    # Brute-force membership test: each of BLOCK_M in0 elements is compared
    # against all N in1 elements, BLOCK_N at a time, then reduced along axis 1.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1] row indices
    row_mask = rows < M
    out_ptr += rows
    # Adding a zero vector broadcasts the pointers to the [BLOCK_M, BLOCK_N]
    # tile shape expected by the 2-D loads below.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]
    # Accumulator starts at the reduction identity: all-True for the AND-style
    # "not found anywhere" (invert) case, all-False for the OR-style case.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]  # [1, BLOCK_N]
        col_mask = cols < N
        # NOTE: in Triton, `and`/`or` on tensors are elementwise (&, |).
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Out-of-bounds lanes are forced to the identity (`invert`) so they
        # cannot affect the reduction.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Collapse the per-column partial results: ALL(!=) when inverted,
    # ANY(==) otherwise (combine_fn chosen at compile time).
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks: elements of in0
    N: int,  # num_tasks_1: elements of in1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,  # how many BLOCK_M tiles each program processes
    invert: tl.constexpr,
):
    # Entry point: each program walks `tiles_per_cta` tiles spaced by the
    # total program count, delegating the per-tile work to the impl above.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.Tensor,
    in1: torch.Tensor,
    invert: bool,
) -> torch.Tensor:
    """Elementwise membership of ``in0`` in ``in1`` by brute-force comparison.

    Suitable for small operands (the caller gates on numel): every element of
    ``in0`` is compared against every element of ``in1`` on device.

    Args:
        in0: query tensor (any shape); output has this shape.
        in1: tensor of candidate values.
        invert: if True, compute "not in" instead of "in".

    Returns:
        Bool tensor shaped like ``in0``.
    """
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Heuristic tile sizes tuned by input size; BLOCK_N is further capped at
    # next_power_of_2(N) inside launch_arg.
    if M <= 1024:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(1, 256, N, 4)
    elif M <= 3072:
        BLOCK_M, BLOCK_N, num_warps = launch_arg(2, 256, N, 4)
    # elif M <= 9216:
    #     BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 256, N, 8)
    else:
        # Covers both the former `M <= 6144` branch and the fallback, which
        # used identical parameters.
        BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)
    if torch.version.hip is not None:
        # ROCm wavefronts: keep num_warps within the supported range.
        num_warps = min(16, num_warps)
    # Cap the grid and let each program loop over the remaining tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[grid](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,  # logical tile index along the M axis (grid-stride id)
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in: must be sorted ascending
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks: elements of in0
    N: int,  # num_tasks_1: elements of in1_sorted
    log_n: tl.constexpr,  # iteration bound, >= ceil(log2(N)) + 1 per caller
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Vectorized binary search: each lane searches sorted in1 for one in0 value.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M
    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)
    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    # Data-dependent while-loop expressed as a fixed-trip-count for-loop:
    # log_n iterations are enough to exhaust any [start, end) range, and
    # while_mask retires lanes whose range has collapsed.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        # NOTE: `or`/`and`/`not` on tl tensors are elementwise (|, &, ~).
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # On an exact hit neither branch fires, so the range stays put and
        # while_mask keeps the lane alive until the loop bound — harmless,
        # since `out` is already latched True.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end
    # store out (elementwise negation when invert is set)
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in: must be sorted ascending
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks: elements of in0
    N: int,  # num_tasks_1: elements of in1_sorted
    log_n: tl.constexpr,  # binary-search iteration bound
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,  # how many BLOCK_M tiles each program processes
    invert: tl.constexpr,
):
    # Entry point: each program walks `tiles_per_cta` tiles spaced by the
    # total program count, delegating the per-tile search to the impl above.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.Tensor,
    in1: torch.Tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
) -> torch.Tensor:
    """Elementwise membership of ``in0`` in ``in1`` via binary search.

    ``in1`` is sorted (or unique-sorted) once, then each element of ``in0``
    is located with a per-lane binary search on device.

    Args:
        in0: query tensor (any shape); output has this shape.
        in1: tensor of candidate values.
        invert: if True, compute "not in" instead of "in".
        unique_in0: deduplicate ``in0`` first and scatter results back.
        unique_in1: deduplicate ``in1`` (it is sorted either way).

    Returns:
        Bool tensor shaped like ``in0``.
    """
    # unique or sort or ravel — in1 must end up sorted for the search kernel.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())
    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Heuristic tile sizes; launch_arg's second slot doubles as BLOCK_M here,
    # capped at next_power_of_2(M).
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    if torch.version.hip is not None:
        # ROCm wavefronts: keep num_warps within the supported range.
        num_warps = min(16, num_warps)
    # Binary-search iteration bound: N.bit_length() == floor(log2(N)) + 1,
    # computed exactly in integer arithmetic (the former float-based
    # int(math.log2(N)) + 1 could be off by one ULP for very large N).
    log_n = N.bit_length()
    # Cap the grid and let each program loop over the remaining tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[grid](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter the per-unique-value answers back to in0's original layout.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor telling which elements of ``in0`` occur in ``in1``.

    Mirrors ``torch.isin``: either argument may be a python scalar (the other
    must be a tensor). ``invert=True`` computes "not in". Dispatches to a
    brute-force comparison kernel for small inputs and to a sort + binary
    search strategy otherwise.
    """
    logger.debug("GEMS ISIN")  # was mislabeled "GEMS ALLCLOSE" (copy-paste)
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Empty shortcut must honor `invert` (torch.isin semantics): with an
        # empty in1, no element is "in", so isin is all-False and the inverted
        # test is all-True. The former zeros_like was wrong for invert=True.
        return torch.full_like(in0, invert, dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        # Small problem: O(M*N) comparison beats the sort overhead.
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        # Already unique (or moderate in1): sort + binary search directly.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        # Very large in1: deduplicate first to shrink the search space.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)