Coverage report for src/flag_gems/runtime/backend/_metax/ops/isin.py: 0% of 127 statements covered (coverage.py v7.6.9, generated 2026-03-25 02:48 +0800).

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.all import reduce_all 

8from flag_gems.ops.any import reduce_any 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .unique import _unique2 

14 

15 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Normalize a launch configuration.

    Caps ``BLOCK_N`` at the next power of two covering ``N`` (so a tiny
    problem never gets an oversized tile) and passes ``BLOCK_M`` and
    ``num_warps`` through unchanged.
    """
    capped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, capped_block_n, num_warps

18 

19 

@triton.jit
def isin_by_comparation_impl(
    global_pid,  # linear tile index over the elements of in0
    in0_ravel_ptr: tl.tensor,  # in: query elements, flattened
    in1_ravel_ptr: tl.tensor,  # in: test elements, flattened
    out_ptr: tl.tensor,  # out: one bool per in0 element
    M: int,  # num_tasks: number of in0 elements
    N: int,  # num_tasks_1: number of in1 elements
    BLOCK_M: tl.constexpr,  # tile_size: in0 elements handled per tile
    BLOCK_N: tl.constexpr,  # tile_size_1: in1 elements compared per step
    invert: tl.constexpr,  # when set, compute "not in" instead of "in"
):
    # Brute-force membership: for one tile of BLOCK_M in0 elements, sweep all
    # of in1 in BLOCK_N-wide chunks, OR-ing equality hits together (AND-ing
    # inequalities when invert is set).
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1) indices
    row_mask = rows < M
    out_ptr += rows
    # Broadcast both pointers up to the (BLOCK_M, BLOCK_N) tile shape.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator seeded with the reduction identity: 1 for the AND (invert)
    # case, 0 for the OR case.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Out-of-range lanes are pinned to the identity value (`invert`) so
        # they cannot influence the row reduction below.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Row-wise reduce: all() when inverted ("never equal"), any() otherwise.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

53 

54 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,  # in: query elements, flattened
    in1_ravel_ptr: tl.tensor,  # in: test elements, flattened
    out_ptr: tl.tensor,  # out: one bool per in0 element
    M: int,  # num_tasks: number of in0 elements
    N: int,  # num_tasks_1: number of in1 elements
    BLOCK_M: tl.constexpr,  # tile_size: in0 elements handled per tile
    BLOCK_N: tl.constexpr,  # tile_size_1: in1 elements compared per step
    tiles_per_cta: int,  # number of tiles each program processes
    invert: tl.constexpr,  # when set, compute "not in" instead of "in"
):
    # Each program walks `tiles_per_cta` tiles, striding by the total program
    # count, and delegates the per-tile work to isin_by_comparation_impl.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_comparation_impl(
            global_pid,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

84 

85 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Elementwise membership test by brute-force pairwise comparison.

    Returns a bool tensor shaped like ``in0`` marking which of its elements
    occur in ``in1`` (or do not occur, when ``invert`` is true). Intended for
    small inputs, where an O(M*N) comparison kernel beats sorting.
    """
    flat0 = in0.contiguous().ravel()
    flat1 = in1.contiguous().ravel()
    num_rows = in0.numel()
    num_cols = in1.numel()

    # Size-bucketed launch configuration: (BLOCK_M, BLOCK_N cap, num_warps).
    if num_rows <= 1024:
        tile_rows, tile_cols, warps = 1, 256, 4
    elif num_rows <= 3072:
        tile_rows, tile_cols, warps = 2, 256, 4
    elif num_rows <= 6144:
        tile_rows, tile_cols, warps = 4, 128, 4
    elif num_rows <= 9216:
        tile_rows, tile_cols, warps = 4, 256, 8
    else:
        tile_rows, tile_cols, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(tile_rows, tile_cols, num_cols, warps)

    # Cap the grid at 65536 programs; leftover tiles are looped inside the
    # kernel via tiles_per_cta.
    ctas_num = min(65536, triton.cdiv(num_rows, BLOCK_M))
    tiles_per_cta = triton.cdiv(num_rows, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            flat0,
            flat1,  # in
            out,  # out
            num_rows,
            num_cols,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

123 

124 

@triton.jit
def isin_by_search_impl(
    global_pid,  # linear tile index over the elements of in0
    in0_ravel_ptr: tl.tensor,  # in: query elements, flattened
    in1_sorted_ptr: tl.tensor,  # in: test elements, sorted ascending
    out_ptr: tl.tensor,  # out: one bool per in0 element
    M: int,  # num_tasks: number of in0 elements
    N: int,  # num_tasks_1: number of sorted in1 elements
    log_n: tl.constexpr,  # loop bound covering log2(N) bisection steps
    BLOCK_M: tl.constexpr,  # tile_size: in0 elements handled per tile
    invert: tl.constexpr,  # when set, compute "not in" instead of "in"
):
    # Membership test via per-lane binary search over the sorted in1 array.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # per-lane "found" flag
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose search interval is still open
    # Fixed trip count stands in for a data-dependent while-loop; lanes that
    # have converged are masked off via while_mask so their state is frozen.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)

159 

160 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,  # in: query elements, flattened
    in1_sorted_ptr: tl.tensor,  # in: test elements, sorted ascending
    out_ptr: tl.tensor,  # out: one bool per in0 element
    M: int,  # num_tasks: number of in0 elements
    N: int,  # num_tasks_1: number of sorted in1 elements
    log_n: tl.constexpr,  # loop bound covering log2(N) bisection steps
    BLOCK_M: tl.constexpr,  # tile_size: in0 elements handled per tile
    tiles_per_cta: int,  # number of tiles each program processes
    invert: tl.constexpr,  # when set, compute "not in" instead of "in"
):
    # Each program walks `tiles_per_cta` tiles, striding by the total program
    # count, and delegates the per-tile work to isin_by_search_impl.
    pid = tle.program_id(0)
    ctas_num = tle.num_programs(0)
    # grid-stride-loop style kernel
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        isin_by_search_impl(
            global_pid,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

190 

191 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Elementwise membership test via binary search over sorted ``in1``.

    Returns a bool tensor shaped like ``in0`` marking which of its elements
    occur in ``in1`` (or do not, when ``invert`` is true). ``unique_in0`` /
    ``unique_in1`` deduplicate the respective input first; for ``in0`` the
    per-unique-value answers are scattered back through the inverse index.
    """
    # Prepare in0: either uniquify (keeping the inverse mapping) or flatten.
    if unique_in0:
        flat0, inverse_idx, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        flat0 = in0.contiguous().ravel()
    # Prepare in1: the search kernel needs it sorted ascending either way.
    if unique_in1:
        sorted1, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        sorted1, _ = torch.sort(in1.ravel())

    num_queries = flat0.numel()
    num_sorted = sorted1.numel()

    # Size-bucketed launch configuration: (BLOCK_M cap, num_warps).
    if num_queries <= 1048576:  # 2 ** 20 = 1024 * 1024
        block_cap, warps = 512, 8
    elif num_queries <= 4194304:  # 2 ** 22 = 1024 * 4096
        block_cap, warps = 1024, 8
    elif num_queries <= 8388608:  # 2 ** 23 = 1024 * 8192
        block_cap, warps = 2048, 16
    elif num_queries <= 268435456:  # 2 ** 28 = 1024 * 262144
        block_cap, warps = 4096, 16
    else:
        block_cap, warps = 2048, 16
    _, BLOCK_M, num_warps = launch_arg(None, block_cap, num_queries, warps)

    # Enough bisection steps to cover the whole sorted array.
    log_n = int(math.log2(num_sorted)) + 1
    # Cap the grid at 65536 programs; leftover tiles loop inside the kernel.
    ctas_num = min(65536, triton.cdiv(num_queries, BLOCK_M))
    tiles_per_cta = triton.cdiv(num_queries, BLOCK_M * ctas_num)
    out = torch.empty_like(flat0, dtype=torch.bool)
    with torch_device_fn.device(flat0.device.index):
        isin_by_search_kernel[(ctas_num,)](
            flat0,
            sorted1,  # in
            out,  # out
            num_queries,
            num_sorted,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Map per-unique-value answers back to the original element order.
        out = torch.gather(out, 0, inverse_idx.ravel().to(torch.int64))
    return out.view_as(in0)

246 

247 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Return a bool tensor shaped like ``in0`` marking which of its elements
    occur in ``in1`` (or do NOT occur, when ``invert=True``).

    Mirrors ``torch.isin``: at least one argument must be a tensor; a scalar
    argument is promoted to a tensor on the other argument's device. The
    strategy is picked by problem size:
      - both inputs small: brute-force pairwise comparison kernel;
      - otherwise: sort ``in1`` and binary-search each element of ``in0``,
        uniquifying ``in1`` first when it is very large.
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # With an empty `in1` no element of `in0` is contained in it, so the
        # answer is all-False — or all-True when invert=True.  (Previously
        # this returned all-False even for invert=True, disagreeing with
        # torch.isin.)  An empty `in0` yields an empty result either way.
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)