Coverage for src/flag_gems/runtime/backend/_ascend/ops/isin.py: 0%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops.all import reduce_all 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11from flag_gems.utils.libentry import libentry 

12 

13from .any import reduce_any 

14from .unique import _unique2 

15 

# Per-op module logger, named "flag_gems.runtime._ascend.ops.isin" (last
# component is derived from this module's own name).
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

17 

18 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Normalize a launch configuration.

    Passes BLOCK_M and num_warps through unchanged, and caps BLOCK_N at the
    smallest power of two >= N so a tile never exceeds the problem size.
    """
    capped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, capped_block_n, num_warps

21 

22 

# Compute one BLOCK_M-row tile of the brute-force isin result: each of the
# BLOCK_M elements of in0 handled by this tile is compared against every
# element of in1 (scanned BLOCK_N columns at a time), and the comparisons are
# reduced to a single boolean per element.
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,
    out_ptr: tl.tensor,
    M: int,
    N: int,
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)
    row_mask = rows < M

    # Build [BLOCK_M, 1]-shaped indices for in0
    in0_offsets = rows[:, None]  # [BLOCK_M, 1]

    # Initialize the result block: all-ones for the AND reduction (invert),
    # all-zeros for the OR reduction (plain membership test).
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)

    # Load in0, one element per row
    in0 = tl.load(
        in0_ravel_ptr + in0_offsets, row_mask[:, None], other=0
    )  # [BLOCK_M, 1]

    # Sweep over the columns of in1
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)
        col_mask = cols < N

        # Combine row/column validity into a 2D mask
        mask = row_mask[:, None] & col_mask[None, :]

        # Load one BLOCK_N-wide chunk of in1
        in1 = tl.load(
            in1_ravel_ptr + cols[None, :], col_mask[None, :], other=0
        )  # [1, BLOCK_N]

        # Out-of-range lanes keep their neutral initial value via the mask,
        # so they do not disturb the final reduction.
        if invert:
            block = tl.where(mask, block & (in0 != in1), block)
        else:
            block = tl.where(mask, block | (in0 == in1), block)

    # Reduce along the column axis: "not equal to any" vs. "equal to some".
    if invert:
        out = tl.reduce(block, axis=1, combine_fn=reduce_all)
    else:
        out = tl.reduce(block, axis=1, combine_fn=reduce_any)

    # Store one boolean per in0 element
    tl.store(out_ptr + rows, out, row_mask)

76 

77 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride loop: each program handles `tiles_per_cta` tiles, advancing
    # by the total number of launched programs between tiles.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile in range(tiles_per_cta):
        isin_by_comparation_impl(
            cta_id + tile * cta_count,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

107 

108 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Membership test by all-pairs comparison on device.

    Compares every element of `in0` against every element of `in1`; intended
    for small inputs. Returns a bool tensor shaped like `in0` — True where the
    element is found in `in1` (or NOT found, when `invert` is True).
    """
    flat0 = in0.contiguous().ravel()
    flat1 = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Heuristic launch configuration keyed on the size of in0:
    # (upper bound on M) -> (BLOCK_M, BLOCK_N, num_warps)
    for bound, cfg in (
        (1024, (1, 256, 4)),
        (3072, (2, 256, 4)),
        (6144, (4, 128, 4)),
        (9216, (4, 256, 8)),
    ):
        if M <= bound:
            break
    else:
        cfg = (4, 128, 4)
    tile_m, tile_n, warps = cfg
    BLOCK_M, BLOCK_N, num_warps = launch_arg(tile_m, tile_n, N, warps)
    # Cap the grid and let the kernel grid-stride over the remaining tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat0,
            flat1,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

146 

147 

# Membership test of BLOCK_M elements of in0 by binary search over the sorted
# array in1_sorted. `log_n` is a host-computed iteration bound that is large
# enough for the search interval to collapse (>= ceil(log2(N)) + 1).
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # Binary search (lower_bound style), vectorized across the block.
    # BUGFIX: the original used Python `or` / `and` / `not` on Triton tensors;
    # inside @triton.jit a tensor has no truth value, so those expressions are
    # invalid. Elementwise `|`, `&` and `== 0` are the correct forms.
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        # Latch a hit for lanes that are still searching; finished lanes keep
        # their result via the where-mask, so stale mid_val does no harm.
        out = tl.where(while_mask, out | (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask & (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask & (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # Elementwise logical NOT when invert is requested, then store.
    if invert:
        out = out == 0
    tl.store(out_ptr + i0, out, mask=mask)

182 

183 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride loop: each program handles `tiles_per_cta` tiles, stepping
    # by the total number of launched programs between tiles.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile in range(tiles_per_cta):
        isin_by_search_impl(
            cta_id + tile * cta_count,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

213 

214 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Membership test by binary search in a sorted copy of `in1`.

    `in1` is sorted (or deduplicated-and-sorted) on the host, then each
    element of `in0` is searched for on device. Returns a bool tensor shaped
    like `in0`; `invert` flips the result elementwise.
    """
    # Prepare operands: deduplicate, or just flatten/sort.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Launch heuristics keyed on the number of searched elements.
    if M <= 2**20:
        tile, warps = 512, 8
    elif M <= 2**22:
        tile, warps = 1024, 8
    elif M <= 2**23:
        tile, warps = 2048, 16
    elif M <= 2**28:
        tile, warps = 4096, 32
    else:
        tile, warps = 2048, 16
    _, BLOCK_M, num_warps = launch_arg(None, tile, M, warps)
    # Iteration bound large enough for the binary-search interval to collapse.
    log_n = int(math.log2(N)) + 1
    # Cap the grid; the kernel grid-strides over any remaining tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter per-unique-value results back to the original element order.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)

269 

270 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of `in0` is present in `in1`.

    Mirrors ``torch.isin``: returns a bool tensor shaped like `in0`. Exactly
    one of the two arguments may be a Python scalar; it is promoted to a
    tensor on the other argument's device. `invert` flips the result.
    Dispatches to brute-force comparison for small inputs and to binary
    search otherwise.
    """
    logger.debug("GEMS_ASCEND ISIN")
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Membership in an empty `in1` is always False, so the inverted
        # result is all-True. (BUGFIX: the original returned all-False even
        # for invert=True, disagreeing with torch.isin.)
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)