Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/isin.py: 0%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.all import reduce_all 

8from flag_gems.ops.any import reduce_any 

9from flag_gems.ops.unique import _unique2 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils.libentry import libentry 

12 

13 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Finalize a kernel launch configuration.

    BLOCK_M and num_warps pass through unchanged; BLOCK_N is capped at the
    smallest power of two that covers N, so the column tile never exceeds
    what one pass over N requires.
    """
    capped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, capped_block_n, num_warps

16 

17 

@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks (elements of in0)
    N: int,  # num_tasks_1 (elements of in1)
    BLOCK_M: tl.constexpr,  # tile_size (rows, i.e. in0 elements per tile)
    BLOCK_N: tl.constexpr,  # tile_size_1 (columns, i.e. in1 elements per step)
    invert: tl.constexpr,
):
    # Brute-force membership for one row tile: compare BLOCK_M elements of
    # in0 against all N elements of in1, BLOCK_N columns at a time, then
    # store one bool per row. With invert=True the stored value is
    # "not found" instead of "found".
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1] indices
    row_mask = rows < M
    out_ptr += rows
    # Adding zero tensors broadcasts the pointers to tile shape so the loads
    # below are [BLOCK_M, BLOCK_N]-shaped.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator starts at the identity of the later reduction:
    # all-true for the AND/"not in" case (invert), all-false for OR/"in".
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): python `and`/`or` on tl tensors relies on Triton's
        # elementwise (non-short-circuit) lowering — confirm this backend's
        # Triton version supports it.
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Fold this column tile into the accumulator; out-of-range lanes are
        # pinned to the reduction identity (`invert`).
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Row-wise reduce: AND across columns when invert, OR otherwise.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

51 

52 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride dispatch: every CTA handles `tiles_per_cta` row tiles,
    # stepping by the launch grid size so the tiles are spread evenly.
    cta_id = tl.program_id(0)
    grid_size = tl.num_programs(0)
    for tile in range(0, tiles_per_cta):
        isin_by_comparation_impl(
            cta_id + tile * grid_size,
            in0_ravel_ptr,
            in1_ravel_ptr,
            out_ptr,
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

82 

83 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Membership mask computed by brute-force comparison of every element of
    ``in0`` against every element of ``in1`` (intended for small inputs).

    Returns a bool tensor shaped like ``in0``; ``invert=True`` negates it.
    """
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Tuned (BLOCK_M, BLOCK_N) tile shapes keyed by problem size.
    if M <= 1024:
        block_m, block_n = 1, 256
    elif M <= 3072:
        block_m, block_n = 2, 256
    elif M <= 6144:
        block_m, block_n = 4, 128
    elif M <= 9216:
        block_m, block_n = 4, 256
    else:
        block_m, block_n = 4, 128
    BLOCK_M, BLOCK_N, num_warps = launch_arg(block_m, block_n, N, 1)
    # Cap the grid, then let each CTA loop over its share of the tiles.
    ctas_num = min(16 // num_warps, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

121 

122 

@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in — assumes ascending order (callers sort/unique) — TODO confirm
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks (elements of in0)
    N: int,  # num_tasks_1 (elements of in1)
    log_n: tl.constexpr,  # iteration budget for the binary search
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Membership for one tile of in0 via a bounded binary search over the
    # sorted in1 array: the data-dependent `while start < end` loop is
    # emulated by a fixed `log_n`-trip loop with a per-lane active mask.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r  # flat indices into in0
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # "found" flags, start all-false
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose search interval is still open
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        # NOTE(review): python `or`/`and`/`not` on tl tensors relies on
        # Triton's elementwise lowering — confirm for this backend.
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # On an exact match neither bound moves, so the lane may revisit the
        # same mid in later trips — harmless, since `out` is already latched
        # and the trip count is fixed.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out (negated when invert=True)
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)

157 

158 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride dispatch: every CTA handles `tiles_per_cta` tiles,
    # stepping by the launch grid size so the tiles are spread evenly.
    cta_id = tl.program_id(0)
    grid_size = tl.num_programs(0)
    for tile in range(0, tiles_per_cta):
        isin_by_search_impl(
            cta_id + tile * grid_size,
            in0_ravel_ptr,
            in1_sorted_ptr,
            out_ptr,
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

188 

189 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Membership mask computed by binary-searching each element of ``in0``
    in a sorted copy of ``in1`` (intended for large inputs).

    ``unique_in0``/``unique_in1`` deduplicate the respective side first; for
    ``in0`` the inverse mapping is kept so the per-unique-value answers can
    be gathered back to the original layout. Returns a bool tensor shaped
    like ``in0``; ``invert=True`` negates it.
    """
    # Prepare in0: deduplicate (keeping the inverse order) or just flatten.
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    # Prepare in1: deduplicate (already sorted) or sort directly.
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Tuned tile size keyed by problem size (thresholds are powers of two).
    if M <= 1 << 20:  # 1048576
        block_m = 512
    elif M <= 1 << 22:  # 4194304
        block_m = 1024
    elif M <= 1 << 23:  # 8388608
        block_m = 2048
    elif M <= 1 << 28:  # 268435456
        block_m = 4096
    else:
        block_m = 2048
    _, BLOCK_M, num_warps = launch_arg(None, block_m, M, 1)
    # Trip budget for the in-kernel binary search: floor(log2(N)) + 1.
    log_n = int(math.log2(N)) + 1
    # Cap the grid, then let each CTA loop over its share of the tiles.
    ctas_num = min(16 // num_warps, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Scatter per-unique answers back to in0's original element order.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)

244 

245 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor of ``in0``'s shape marking which elements occur
    in ``in1`` (mirrors ``torch.isin``).

    At least one of ``in0``/``in1`` must be a tensor; the other is promoted
    with ``torch.tensor`` on the same device. ``invert=True`` negates the
    mask. ``assume_unique`` only steers strategy selection: small inputs use
    brute-force comparison, larger ones use sorted binary search, optionally
    deduplicating ``in1`` first.
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Bug fix: with an empty in1, no element is contained, so the result
        # is all-False — but with invert=True it must be all-True (nothing is
        # in the empty set). The previous unconditional zeros_like returned
        # all-False even for invert=True. For an empty in0 the result is
        # empty either way, so full_like is correct for both cases.
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)