Coverage for src/flag_gems/runtime/backend/_ascend/ops/isin.py: 0%
132 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops.all import reduce_all
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
11from flag_gems.utils.libentry import libentry
13from .any import reduce_any
14from .unique import _unique2
16logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Return the (BLOCK_M, BLOCK_N, num_warps) launch triple, with BLOCK_N
    clamped to the smallest power of two that covers N."""
    pow2_cols = triton.next_power_of_2(N)
    cols = BLOCK_N if BLOCK_N < pow2_cols else pow2_cols
    return BLOCK_M, cols, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,
    out_ptr: tl.tensor,
    M: int,
    N: int,
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    """Membership test for one tile of BLOCK_M elements of flattened in0
    against all N elements of flattened in1, by direct comparison.

    Writes one bool per in0 element to out_ptr: whether it occurs in in1
    (negated when invert=True).
    """
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)
    row_mask = rows < M

    # Build [BLOCK_M, 1]-shaped indices for in0
    in0_offsets = rows[:, None]  # [BLOCK_M, 1]

    # Initialize the result tile: all-ones when invert (AND-reduced below),
    # all-zeros otherwise (OR-reduced below).
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)

    # Load in0, one element per row
    in0 = tl.load(
        in0_ravel_ptr + in0_offsets, row_mask[:, None], other=0
    )  # [BLOCK_M, 1]

    # Sweep over the columns of in1 in BLOCK_N-wide chunks
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)
        col_mask = cols < N

        # Build the 2D validity mask for this chunk
        mask = row_mask[:, None] & col_mask[None, :]

        # Load one chunk of in1
        in1 = tl.load(
            in1_ravel_ptr + cols[None, :], col_mask[None, :], other=0
        )  # [1, BLOCK_N]

        if invert:
            # A lane stays 1 only while it differs from every in1 value seen.
            block = tl.where(mask, block & (in0 != in1), block)
        else:
            # A lane becomes 1 once any in1 value matches.
            block = tl.where(mask, block | (in0 == in1), block)

    # Reduce along the column axis to one bool per row
    if invert:
        out = tl.reduce(block, axis=1, combine_fn=reduce_all)
    else:
        out = tl.reduce(block, axis=1, combine_fn=reduce_any)

    # Store the per-row result
    tl.store(out_ptr + rows, out, row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    """Entry kernel for the comparison-based isin path.

    Each program walks `tiles_per_cta` tiles, striding by the program count,
    and delegates each tile to isin_by_comparation_impl.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Elementwise membership of `in0` in `in1` via brute-force comparison.

    Returns a bool tensor shaped like `in0`; each entry is negated when
    `invert` is True. Intended for small inputs (the dispatcher caps both
    sides at 12288 elements).
    """
    flat0 = in0.contiguous().ravel()
    flat1 = in1.contiguous().ravel()
    M, N = in0.numel(), in1.numel()

    # Size-dependent launch configuration: (BLOCK_M, BLOCK_N cap, num_warps).
    if M <= 1024:
        rows, cols, warps = 1, 256, 4
    elif M <= 3072:
        rows, cols, warps = 2, 256, 4
    elif M <= 6144:
        rows, cols, warps = 4, 128, 4
    elif M <= 9216:
        rows, cols, warps = 4, 256, 8
    else:
        rows, cols, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(rows, cols, N, warps)

    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat0,
            flat1,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    """Binary-search each element of one BLOCK_M tile of flattened in0 in
    the ascending-sorted array at `in1_sorted_ptr`.

    The search is a lower_bound unrolled `log_n` times; `while_mask` tracks
    lanes whose interval [start, end) is still open. One bool per element is
    written to out_ptr (negated when invert=True).
    """
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        # BUG FIX: the original used Python `or`/`and` on tensors, which
        # forces a tensor-to-bool conversion and is invalid inside a Triton
        # kernel; elementwise `|`/`&` are required.
        out = tl.where(while_mask, out | (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask & (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask & (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out, flipping each lane when invert=True (elementwise negation
    # instead of Python `not`, which is likewise invalid on tensors)
    result = (out == 0) if invert else out
    tl.store(out_ptr + i0, result, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    """Entry kernel for the binary-search isin path.

    Each program walks `tiles_per_cta` tiles, striding by the program count,
    and delegates each tile to isin_by_search_impl. `in1_sorted_ptr` must
    point at ascending-sorted data; `log_n` bounds the search iterations.
    """
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Elementwise membership of `in0` in `in1` via binary search.

    `in1` is sorted (or deduplicated-and-sorted) first; the kernel then
    lower-bound searches each element of `in0`. When `unique_in0` is set,
    the search runs on unique values only and the per-value answers are
    gathered back to the original layout. Returns a bool tensor shaped
    like `in0`, negated when `invert` is True.
    """
    # Prepare in0: deduplicate (keeping the inverse mapping) or just flatten.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    # Prepare in1: the kernel requires ascending order either way.
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    M = in0_ravel.numel()
    N = in1_ravel.numel()

    # Size-dependent launch configuration: (BLOCK_M cap, num_warps).
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        cap, warps = 512, 8
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        cap, warps = 1024, 8
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        cap, warps = 2048, 16
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        cap, warps = 4096, 32
    else:
        cap, warps = 2048, 16
    _, BLOCK_M, num_warps = launch_arg(None, cap, M, warps)

    log_n = int(math.log2(N)) + 1  # unrolled binary-search depth
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Map the per-unique-value answers back onto the original elements.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor shaped like `in0` marking which elements occur
    in `in1` (negated when `invert` is True).

    Non-tensor scalars are promoted to tensors on the other argument's
    device. Small problems use the brute-force comparison kernel; larger
    ones use the binary-search kernel, deduplicating `in1` first when it
    is very large and uniqueness is not already assumed.
    """
    logger.debug("GEMS_ASCEND ISIN")
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # BUG FIX: with invert=True and an empty `in1`, no element is
        # contained, so every result must be True; the original returned
        # all-False regardless of `invert`. (An empty `in0` yields an empty
        # result either way.)
        return torch.full_like(in0, invert, dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)