# Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isin.py: 0%
# 138 statements
# coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import math
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.libentry import libentry
12from .all import reduce_all
13from .any import reduce_any
14from .unique import _unique2
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Pass launch parameters through, capping BLOCK_N at the smallest
    power of two that covers N elements (no point tiling wider than N)."""
    capped_block_n = triton.next_power_of_2(N)
    if BLOCK_N < capped_block_n:
        capped_block_n = BLOCK_N
    return BLOCK_M, capped_block_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks (elements of in0)
    N: int,  # num_tasks_1 (elements of in1)
    BLOCK_M: tl.constexpr,  # tile_size (rows: in0 elements per tile)
    BLOCK_N: tl.constexpr,  # tile_size_1 (cols: in1 elements per chunk)
    invert: tl.constexpr,
):
    # Brute-force membership test for one tile of in0: each of the BLOCK_M
    # rows holds one element of in0 and sweeps all N elements of in1 in
    # chunks of BLOCK_N, accumulating equality hits (or inequality when
    # invert=True) into `block`, then reduces each row to a single flag.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1] indices
    row_mask = rows < M
    out_ptr += rows
    # Adding a zeros tensor broadcasts the scalar pointers to the
    # [BLOCK_M, BLOCK_N] tile shape.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Start from the reduction identity: 1 for all(!=) when invert,
    # 0 for any(==) otherwise.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Out-of-range lanes are forced to the identity value (`invert`)
        # so they cannot affect the row reduction below.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Row-wise reduce: all(in0 != in1) when invert, else any(in0 == in1).
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride-loop wrapper: each CTA walks tiles_per_cta tiles, spaced
    # num_programs(0) apart, and delegates each tile to the impl function.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile_idx in range(tiles_per_cta):
        isin_by_comparation_impl(
            cta_id + tile_idx * cta_count,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Elementwise membership of `in0` in `in1` by brute-force comparison.

    Small-input path: returns a bool tensor shaped like `in0` that is True
    where the element occurs in `in1` (flipped when `invert` is True).
    """
    flat0 = in0.contiguous().ravel()
    flat1 = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Heuristic tile/warp configuration keyed on the number of in0 elements.
    if M <= 1024:
        tile_m, tile_n, warps = 1, 256, 4
    elif M <= 3072:
        tile_m, tile_n, warps = 2, 256, 4
    elif M <= 6144:
        tile_m, tile_n, warps = 4, 128, 4
    elif M <= 9216:
        tile_m, tile_n, warps = 4, 256, 8
    else:
        tile_m, tile_n, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(tile_m, tile_n, N, warps)
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat0,
            flat1,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in (sorted ascending by the host)
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks (elements of in0)
    N: int,  # num_tasks_1 (elements of in1)
    log_n: tl.constexpr,  # iteration bound for the search loop, from host
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Membership test for one tile of BLOCK_M in0 elements via per-lane
    # binary search over the sorted in1 buffer.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # Binary search (lower_bound style). The fixed-trip `for` loop stands in
    # for a data-dependent while-loop: `while_mask` retires lanes whose
    # interval [start, end) has emptied; log_n iterations suffice to shrink
    # an N-element interval to empty.
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out; inactive lanes get an offset past the buffer but their
    # store is disabled by `mask` anyway
    out_offset = tl.where(mask, i0, M + 1)
    tl.store(out_ptr + out_offset, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride-loop wrapper: each CTA walks tiles_per_cta tiles, spaced
    # num_programs(0) apart, and delegates each tile to the impl function.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile_idx in range(tiles_per_cta):
        isin_by_search_impl(
            cta_id + tile_idx * cta_count,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Elementwise membership of `in0` in `in1` via device-side binary search.

    `in1` is sorted (deduplicated first when `unique_in1`) and each element
    of `in0` is looked up with a lower-bound binary search. Returns a bool
    tensor shaped like `in0`; True where found, flipped when `invert`.
    """
    # Deduplicate / sort / flatten the inputs as requested.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    # Heuristic block/warp configuration keyed on M; launch_arg additionally
    # caps the block size at the next power of two covering M.
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    # Binary-search iteration bound: N.bit_length() == floor(log2(N)) + 1,
    # equivalent to int(math.log2(N)) + 1 but exact for any integer N >= 1.
    log_n = N.bit_length()
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)

    # Backend-specific environment toggles for this launch only. Save any
    # pre-existing values and restore them afterwards: the previous code
    # deleted the variables unconditionally (clobbering caller settings)
    # and leaked the overrides if the launch raised.
    overrides = {
        "TRITONXPU_OTHER_SIM": "1",
        "TRITONXPU_STORE_MASK_SIM": "1",
        "TRITONXPU_INTERLEAVE": "0",
    }
    saved = {key: os.environ.get(key) for key in overrides}
    with torch_device_fn.device(in0_ravel.device.index):
        os.environ.update(overrides)
        try:
            isin_by_search_kernel[grid](
                in0_ravel,
                in1_ravel,  # in
                out,  # out
                M,
                N,
                log_n,
                BLOCK_M,
                tiles_per_cta=tiles_per_cta,
                invert=invert,
                num_warps=num_warps,
                isCloseUnrollControl=True,
            )
        finally:
            for key, old in saved.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old

    if unique_in0:
        # Map results on the unique values back to the original layout.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of `in0` appears in `in1`.

    Mirrors torch.isin: returns a bool tensor shaped like `in0` whose entry
    i is True iff in0[i] occurs in in1, flipped elementwise when `invert`
    is True. At least one of `in0`/`in1` must be a tensor; a non-tensor
    argument is promoted to a tensor on the other argument's device.
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # With an empty `in1`, no element is "in" it, so the answer is
        # all-`invert` (the previous all-False result was wrong for
        # invert=True). An empty `in0` yields an empty tensor either way.
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        # Small inputs: brute-force pairwise comparison.
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        # Binary search against a sorted copy of in1.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        # Very large in1: deduplicate it first to shrink the search space.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)