Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isin.py: 0%

138 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import math 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.libentry import libentry 

11 

12from .all import reduce_all 

13from .any import reduce_any 

14from .unique import _unique2 

15 

16 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Pass through a launch configuration, capping BLOCK_N so it never
    exceeds the next power of two covering N elements."""
    capped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, capped_block_n, num_warps

19 

20 

@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # One tile of brute-force membership: compares BLOCK_M elements of in0
    # against all N elements of in1 (scanned in BLOCK_N-wide chunks) and
    # writes, per in0 element, whether it occurs in in1 (sense flipped when
    # `invert` is set).
    row_off = global_pid * BLOCK_M
    # [BLOCK_M, 1] flat indices of the in0 elements owned by this tile.
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    out_ptr += rows
    # Broadcast the row pointers to [BLOCK_M, BLOCK_N] so each in0 value is
    # replicated across the BLOCK_N comparison lanes.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator starts at the reduction identity: all-True when invert
    # (AND-reduced below), all-False otherwise (OR-reduced below).
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on tensors here relies on the backend's Triton
        # frontend lowering it elementwise — confirm against other kernels.
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # invert: stay True only while in0 differs from every in1 seen so
        # far; otherwise: become True once any in1 matches. Out-of-range
        # lanes are pinned to the identity value (`invert`).
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Collapse the BLOCK_N lanes: AND for "not in" (invert), OR for "in".
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

54 

55 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Kernel entry point: each program walks `tiles_per_cta` tiles, stepping
    # by the grid size, and delegates per-tile work to
    # isin_by_comparation_impl (which masks off out-of-range rows).
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

85 

86 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Device membership test by brute-force comparison.

    Compares every element of `in0` against every element of `in1` and
    returns a bool tensor shaped like `in0`: True where the element occurs
    in `in1` (sense flipped when `invert` is set).
    """
    flat_in0 = in0.contiguous().ravel()
    flat_in1 = in1.contiguous().ravel()
    M = flat_in0.numel()
    N = flat_in1.numel()
    # Size-tuned launch table: (M upper bound, BLOCK_M, BLOCK_N, num_warps);
    # the last row is the fallback for larger M.
    launch_table = (
        (1024, 1, 256, 4),
        (3072, 2, 256, 4),
        (6144, 4, 128, 4),
        (9216, 4, 256, 8),
    )
    for bound, bm, bn, warps in launch_table:
        if M <= bound:
            break
    else:
        bm, bn, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(bm, bn, N, warps)
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(flat_in0, dtype=torch.bool)
    with torch_device_fn.device(flat_in0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat_in0,
            flat_in1,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

124 

125 

@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in (assumed sorted ascending by the host)
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # One tile of search-based membership: binary-search each of BLOCK_M
    # in0 elements in the sorted in1 array. `log_n` is a host-computed
    # upper bound on the number of bisection steps.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    # Fixed-trip-count loop emulating a data-dependent while loop: lanes
    # whose [start, end) interval is exhausted are masked off via
    # `while_mask` so their state no longer changes.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # Standard interval narrowing; an exact match leaves both bounds in
        # place, and `out` already recorded the hit.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    # NOTE(review): masked-off lanes get offset M + 1, but the store below
    # is also masked, so that out-of-range offset is never dereferenced.
    # `not out` relies on the backend's Triton frontend lowering logical-not
    # elementwise — confirm against other kernels in this backend.
    out_offset = tl.where(mask, i0, M + 1)
    tl.store(out_ptr + out_offset, not out if invert else out, mask=mask)

161 

162 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in (assumed sorted ascending by the host)
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Kernel entry point: each program walks `tiles_per_cta` tiles, stepping
    # by the grid size, and delegates per-tile work to isin_by_search_impl
    # (which masks off out-of-range elements).
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

192 

193 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Membership test via binary search of each `in0` element in a sorted
    copy of `in1`.

    Args:
        in0: query elements (any shape).
        in1: test elements (any shape).
        invert: if True, report "not in" instead of "in".
        unique_in0: deduplicate `in0` first and scatter the per-unique
            results back through the inverse mapping.
        unique_in1: deduplicate (and sort) `in1` via `_unique2` instead of
            a plain sort.

    Returns:
        Bool tensor shaped like `in0`.
    """
    # Canonicalize in0: either deduplicate (keeping the inverse mapping so
    # results can be scattered back at the end) or just flatten.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    # Canonicalize in1: the kernel's binary search requires ascending order.
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())
    # Size-tuned launch configuration; launch_arg additionally caps BLOCK_M
    # at the next power of two >= M.
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    # floor(log2(N)) + 1 bounds the bisection steps of the kernel's
    # fixed-trip-count search loop.
    log_n = int(math.log2(N)) + 1
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    # Backend-specific flags must only be active for this launch. The
    # try/finally guarantees they are cleared even when the launch raises
    # (the previous code leaked them into the process on error).
    xpu_flags = (
        "TRITONXPU_OTHER_SIM",
        "TRITONXPU_STORE_MASK_SIM",
        "TRITONXPU_INTERLEAVE",
    )
    with torch_device_fn.device(in0_ravel.device.index):
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        os.environ["TRITONXPU_INTERLEAVE"] = "0"
        try:
            isin_by_search_kernel[grid](
                in0_ravel,
                in1_ravel,  # in
                out,  # out
                M,
                N,
                log_n,
                BLOCK_M,
                tiles_per_cta=tiles_per_cta,
                invert=invert,
                num_warps=num_warps,
                isCloseUnrollControl=True,
            )
        finally:
            for flag in xpu_flags:
                os.environ.pop(flag, None)

    if unique_in0:
        # Map per-unique-value results back to the original element order.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)

265 

266 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor shaped like `in0` marking which elements occur
    in `in1` (negated when `invert` is True), dispatching to a brute-force
    comparison kernel for small inputs and a sort+binary-search kernel for
    large ones.

    Args:
        in0: query elements; tensor or scalar (promoted to a tensor on
            `in1`'s device).
        in1: test elements; tensor or scalar (promoted likewise).
        assume_unique: caller asserts inputs are already unique, skipping
            deduplication of `in1` in the search path.
        invert: flip the result so True means "not in `in1`".
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # With an empty test set no element is "in" it, so the result is
        # all-False for invert=False but all-True for invert=True (matching
        # torch.isin). The previous code returned zeros unconditionally,
        # which was wrong for invert=True. (An empty in0 yields an empty
        # result either way.)
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)