Coverage for src/flag_gems/runtime/backend/_mthreads/ops/isin.py: 0%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.all import reduce_all 

8from flag_gems.ops.any import reduce_any 

9from flag_gems.ops.unique import _unique2 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.libentry import libentry 

13 

14 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Pass the launch tuple through, clamping BLOCK_N to the smallest
    power of two that covers N (no point tiling wider than the data)."""
    clamped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, clamped_block_n, num_warps

17 

18 

@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # Brute-force membership tile: compares BLOCK_M elements of in0 against
    # all N elements of in1 (scanned in chunks of BLOCK_N). For each in0
    # element the result is any(in0 == in1), or all(in0 != in1) when invert
    # is set.
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1) indices into in0
    row_mask = rows < M
    out_ptr += rows
    # Adding a zeros tensor broadcasts the pointers to a (BLOCK_M, BLOCK_N)
    # tile so the loads below produce 2D blocks.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator starts at the reduction identity: True for the AND/invert
    # case, False for the OR case.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Fold this chunk of in1 into the accumulator; out-of-range lanes are
        # forced back to the identity value (invert) so they cannot corrupt
        # the final reduction.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    # Collapse the BLOCK_N lane axis: AND-reduce when invert, else OR-reduce.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

52 

53 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride driver: each CTA walks tiles_per_cta tiles, stepping by the
    # total number of launched CTAs so the whole of in0 is covered.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile in range(tiles_per_cta):
        isin_by_comparation_impl(
            cta_id + tile * cta_count,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

83 

84 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Pairwise-comparison path: the kernel checks every element of in0
    against every element of in1. Intended for small problem sizes."""
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()
    # Launch-shape heuristics keyed on the size of in0:
    # (BLOCK_M, BLOCK_N, num_warps) per size bucket.
    if M <= 1024:
        block_m, block_n, warps = 1, 256, 4
    elif M <= 3072:
        block_m, block_n, warps = 2, 256, 4
    elif M <= 6144:
        block_m, block_n, warps = 4, 128, 4
    elif M <= 9216:
        block_m, block_n, warps = 4, 256, 8
    else:
        block_m, block_n, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(block_m, block_n, N, warps)
    # Cap the grid; leftover tiles are folded into a grid-stride loop.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

122 

123 

@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Binary-search membership tile: each of the BLOCK_M in0 elements handled
    # here is looked up in in1_sorted_ptr, which the host sorted ascending.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)  # per-lane "found" flag
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end  # lanes whose search interval is still open
    # A data-dependent while-loop is emulated with a fixed trip count: log_n
    # iterations are enough to shrink any interval of length N to empty.
    # Closed lanes load index 0 under while_mask=False, which is harmless.
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    # NOTE(review): "not out" relies on Triton rewriting `not` elementwise on
    # the flag vector when invert is requested — confirm on target toolchain.
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)

158 

159 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Grid-stride driver: each CTA processes tiles_per_cta tiles, advancing
    # by the launched CTA count so every element of in0 gets searched.
    cta_id = tle.program_id(0)
    cta_count = tle.num_programs(0)
    for tile in range(tiles_per_cta):
        isin_by_search_impl(
            cta_id + tile * cta_count,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

189 

190 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Search path: sort (optionally deduplicate) in1, then binary-search
    every element of in0 in the sorted buffer on device."""
    # unique or sort or ravel
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())
    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    # Tile-size heuristics keyed on M (bucket limits: 2**20, 2**22, 2**23, 2**28).
    if M <= 1 << 20:
        block_hint = 512
    elif M <= 1 << 22:
        block_hint = 1024
    elif M <= 1 << 23:
        block_hint = 2048
    elif M <= 1 << 28:
        block_hint = 4096
    else:
        block_hint = 2048
    _, BLOCK_M, num_warps = launch_arg(None, block_hint, M, 8)
    # Enough halving steps to close any search interval over N elements.
    log_n = int(math.log2(N)) + 1
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_search_kernel[(ctas_num,)](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    if unique_in0:
        # Map results on the deduplicated buffer back to in0's original order.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)

245 

246 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of in0 is present in in1.

    Mirrors torch.isin: in0/in1 may be tensors or scalars (at least one must
    be a tensor). Returns a bool tensor shaped like in0; with invert=True the
    result is negated (True where the element is NOT in in1).
    """
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        assert torch.is_tensor(in0)
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # With an empty in1 no element of in0 is a member, so the result is
        # all False — or all True when invert is set. (Bug fix: this
        # previously returned all False even with invert=True, disagreeing
        # with torch.isin.)
        return torch.full_like(in0, bool(invert), dtype=torch.bool)
    elif in0.numel() <= 12288 and in1.numel() <= 12288:  # 1024 * 12
        # Small problems: brute-force pairwise comparison kernel.
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        # Moderate in1 (or caller guarantees uniqueness): sort + binary search.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        # Very large in1: deduplicate it first to shrink the search buffer.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)