Coverage for src/flag_gems/runtime/backend/_cambricon/ops/index_select.py: 0%

167 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, libtuner 

10 

11from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

def get_max_block_size(dtype_size):
    """Return the max number of elements a block may hold on-chip.

    NRAM is split three ways (input tile, output tile, scratch), so one
    buffer gets a third of it, expressed in elements of `dtype_size` bytes.
    """
    per_buffer_bytes = MAX_NRAM_SIZE // 3
    return per_buffer_bytes // dtype_size

18 

19 

def config_prune(configs, named_args, **kwargs):
    """Prune BLOCK_SIZE configs for the one-batch (2D) index_select kernel.

    Keeps one config per distinct index-block size (BLOCK_SIZE rounded up to
    a multiple of the row width N) that still fits in NRAM, then appends one
    extra config sized so the index elements divide evenly over the cores.
    """
    n = named_args["N"]
    limit = get_max_block_size(named_args["dtype_size"])

    kept = []
    seen_index_blocks = set()
    for cfg in configs:
        block = cfg.kwargs["BLOCK_SIZE"]
        rows = (block + n - 1) // n  # index elements covered by this block
        if rows in seen_index_blocks or rows * n > limit:
            continue
        seen_index_blocks.add(rows)
        kept.append(cfg)

    total = named_args["in_n_elements"]

    # make sure at least one config is at the load-balance sweet point
    sweet = max(total // TOTAL_CORE_NUM, 1) * n
    if total % TOTAL_CORE_NUM != 0:
        sweet += 1
    sweet = min(sweet, limit)
    if (sweet + n - 1) // n not in seen_index_blocks:
        kept.append(
            triton.Config(kwargs={"BLOCK_SIZE": sweet}, num_stages=1, num_warps=1)
        )

    return kept

47 

48 

@triton.jit
def ld_st_1(indices, N: tl.constexpr, weight_ptr, in_mask, in_offsets, out_ptr):
    # Gather N-wide rows of `weight_ptr` at `indices` and store them at rows
    # `in_offsets` of `out_ptr`.  Shapes: indices/in_mask/in_offsets are 1-D
    # (one entry per selected row); each load/store moves a (rows, N) tile.
    weight_offsets = indices[:, None] * N + tl.arange(0, N)
    embedding_weight = tl.load(weight_ptr + weight_offsets, in_mask[:, None])
    # destination rows are contiguous in the output (row stride N)
    out_offsets = in_offsets[:, None] * N + tl.arange(0, N)
    tl.store(out_ptr + out_offsets, embedding_weight, in_mask[:, None])

55 

56 

@libentry()
@libtuner(
    configs=[
        # [512, 65536]
        triton.Config(kwargs={"BLOCK_SIZE": 512 * 2**i}, num_stages=1, num_warps=1)
        for i in range(0, 8, 2)
    ],
    key=["N"],
    prune_configs_by={
        "early_config_prune": config_prune,
    },
)
@triton.jit
def one_batch_index_select_kernel(  # 2D
    out_ptr,
    in_ptr,
    in_n_elements,
    weight_ptr,
    N: tl.constexpr,
    dtype_size,
    inp_numel,
    BLOCK_SIZE: tl.constexpr,
):
    # Fast path for batch_dim == 1: the input is a (select_dim, N) table and
    # each of the `in_n_elements` indices in `in_ptr` picks one full N-wide
    # row into `out_ptr`.  Programs cooperatively walk the index array in
    # strides of num_jobs * INDEX_BLOCK_SIZE.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)

    # how many indices one BLOCK_SIZE-element tile covers (ceil division)
    INDEX_BLOCK_SIZE: tl.constexpr = (BLOCK_SIZE + N - 1) // N

    step = num_jobs * INDEX_BLOCK_SIZE
    iters = tl.cdiv(in_n_elements, step)

    # TODO: remove dtype_size once contiguous DMA is ensured
    # int32 offsets are only safe while the whole tensor fits in 2 GiB
    small_out = inp_numel.to(tl.int64) * dtype_size <= 2**31

    for i in tl.range(iters):
        iter_start = i * step
        iter_end = iter_start + step

        if iter_end <= in_n_elements:
            # full iteration: every program gets a complete block
            block_offset = iter_start + pid * INDEX_BLOCK_SIZE
            block_len = INDEX_BLOCK_SIZE
        else:
            # tail iteration: spread the remainder evenly, the first
            # `remn_num` programs take one extra element each
            rem_n_elements = in_n_elements - iter_start
            base_num = rem_n_elements // num_jobs
            remn_num = rem_n_elements % num_jobs
            extra_one = pid < remn_num

            block_offset = iter_start + (
                (base_num + 1) * pid if extra_one else (base_num * pid + remn_num)
            )
            block_len = base_num + extra_one

        in_offsets = block_offset + tl.arange(0, INDEX_BLOCK_SIZE)
        in_mask = in_offsets < (block_offset + block_len)
        # NOTE(review): `other=0.0` is a float fill for an integer index
        # load — presumably cast to the index dtype by tl.load; confirm.
        indices = tl.load(in_ptr + in_offsets, in_mask, other=0.0)
        if indices.dtype != tl.int32 and small_out:
            # narrow 64-bit indices to 32-bit for cheaper addressing
            indices_int32 = indices.to(tl.int32)
            ld_st_1(indices_int32, N, weight_ptr, in_mask, in_offsets, out_ptr)
        else:
            ld_st_1(indices, N, weight_ptr, in_mask, in_offsets, out_ptr)

117 

118 

def config_prune(configs, named_args, **kwargs):
    """Build the config set for the multi-batch (3D) index_select kernel.

    NOTE: this intentionally redefines the earlier `config_prune`; the
    one-batch kernel already captured the previous definition at decoration
    time.  Candidate block sizes per axis are a few load-balance-friendly
    values plus whatever the incoming configs propose, and every in-budget
    combination of them is emitted.
    """
    # TODO: bad perf when BLOCK_BATCH is 1
    batch_dim = max(named_args["batch_dim"], 2)
    index_dim = named_args["index_dim"]
    c_dim = named_args["c_dim"]
    max_bs = get_max_block_size(named_args["dtype_size"])

    # difficult to include these critical configs while keeping number of configs small
    batch_candidates = {
        triton.cdiv(batch_dim, TOTAL_CORE_NUM),
        max(batch_dim // TOTAL_CORE_NUM, 1),
        batch_dim,
    }
    index_candidates = {
        triton.cdiv(index_dim, TOTAL_CORE_NUM),
        max(index_dim // TOTAL_CORE_NUM, 1),
        index_dim,
    }
    c_candidates = {c_dim, min(max_bs, c_dim)}

    # to keep the autotune space small: if c_dim is not very large, don't split c
    c_split_threshold = 2048 * 5
    for cfg in configs:
        bb = cfg.kwargs["BLOCK_BATCH"]
        bi = cfg.kwargs["BLOCK_INDEX"]
        bc = c_dim if c_dim <= c_split_threshold else cfg.kwargs["BLOCK_C"]
        if bb <= batch_dim and bi <= index_dim and bc <= c_dim:
            batch_candidates.add(bb)
            index_candidates.add(bi)
            c_candidates.add(bc)

    # cross product of candidates, keeping only tiles that fit in NRAM
    return [
        triton.Config(
            {
                "BLOCK_BATCH": bb,
                "BLOCK_INDEX": bi,
                "BLOCK_C": bc,
            },
            num_warps=1,
            num_stages=1,
        )
        for bb in batch_candidates
        for bi in index_candidates
        for bc in c_candidates
        if bb * bi * bc <= max_bs
    ]

169 

170 

@triton.jit
def ld_st_2(
    inp,
    out,
    batch_offsets,
    index_offsets,
    c_offsets,
    inp_strides_0,
    inp_strides_1,
    out_strides_0,
    out_strides_1,
    index_cur,
    input_output_mask,
):
    # Copy one (batch, index, c) tile of the collapsed 3D views: rows
    # `index_cur` of the input's select dimension are read and written at
    # rows `index_offsets` of the output.  The three 1-D offset tensors are
    # broadcast into a (BATCH, INDEX, C) cube of flat element offsets.
    input_offsets = (batch_offsets * inp_strides_0)[:, None, None] + (
        (index_cur * inp_strides_1)[:, None] + c_offsets[None, :]
    )[None, :, :]

    output_offsets = (batch_offsets * out_strides_0)[:, None, None] + (
        (index_offsets * out_strides_1)[:, None] + c_offsets[None, :]
    )[None, :, :]

    # same mask on load and store: out-of-range lanes read 0 and write nothing
    selected = tl.load(inp + input_offsets, mask=input_output_mask, other=0.0)
    tl.store(out + output_offsets, selected, mask=input_output_mask)

195 

196 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("index_select"),
    key=["batch_dim", "index_dim", "c_dim"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.jit
def multi_batch_index_select_kernel(
    inp,
    index,
    out,
    batch_dim,
    select_dim,
    c_dim,
    index_dim,
    dtype_size,
    inp_numel,
    BLOCK_BATCH: tl.constexpr,
    BLOCK_INDEX: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # General path: input viewed as [batch_dim, select_dim, c_dim], output as
    # [batch_dim, index_dim, c_dim].  The (batch, index, c) tile grid is
    # flattened into block ids handed out round-robin across programs.
    pid_x = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    block_id_start = pid_x
    block_id_step = num_programs

    block_batch: tl.constexpr = BLOCK_BATCH
    block_index: tl.constexpr = BLOCK_INDEX
    block_c: tl.constexpr = BLOCK_C

    # number of tiles along each axis
    block_num_batch = tl.cdiv(batch_dim, block_batch)
    block_num_index = tl.cdiv(index_dim, block_index)
    block_num_c = tl.cdiv(c_dim, block_c)

    block_num_total = block_num_batch * block_num_index * block_num_c

    # row strides of the contiguous collapsed 3D input/output views
    inp_strides_0, inp_strides_1 = [select_dim * c_dim, c_dim]
    out_strides_0, out_strides_1 = [index_dim * c_dim, c_dim]
    # strides to decompose a flat block id into (batch, index, c) tile ids
    block_strides_0, block_strides_1 = [block_num_index * block_num_c, block_num_c]

    # TODO: remove dtype_size once contiguous DMA is ensured
    # int32 offsets are only safe while the whole tensor fits in 2 GiB
    small_out = inp_numel.to(tl.int64) * dtype_size <= 2**31

    for block_id in tl.range(block_id_start, block_num_total, block_id_step):
        block_id_batch = block_id // block_strides_0
        block_id_index = (block_id // block_strides_1) % block_num_index
        block_id_c = block_id % block_num_c

        # arange requires constexpr
        batch_offsets = block_id_batch * block_batch + tl.arange(0, block_batch)
        batch_mask = batch_offsets < batch_dim

        index_offsets = block_id_index * block_index + tl.arange(0, block_index)
        index_mask = index_offsets < index_dim

        c_offsets = block_id_c * block_c + tl.arange(0, block_c)
        c_mask = c_offsets < c_dim

        # NOTE(review): Python `and` between tensors relies on the Triton
        # frontend rewriting it to an elementwise logical-and inside @jit —
        # confirm this toolchain does so (tl.logical_and / `&` is explicit).
        input_output_mask = (
            batch_mask[:, None, None]
            and (index_mask[:, None] and c_mask[None, :])[None, :, :]
        )

        index_cur = tl.load(index + index_offsets, mask=index_mask, other=0)
        # TODO: remove dtype_size once contiguous DMA is ensured
        # NOTE(review): `index.dtype` is the dtype of the *pointer* argument,
        # so this comparison with tl.int32 looks always-true; harmless because
        # an int32->int32 cast is a no-op, but cf. the one-batch kernel which
        # checks the loaded tensor's dtype — verify intent.
        if index.dtype != tl.int32 and small_out:
            index_cur_int32 = index_cur.to(tl.int32)
            ld_st_2(
                inp,
                out,
                batch_offsets,
                index_offsets,
                c_offsets,
                inp_strides_0,
                inp_strides_1,
                out_strides_0,
                out_strides_1,
                index_cur_int32,
                input_output_mask,
            )
        else:
            ld_st_2(
                inp,
                out,
                batch_offsets,
                index_offsets,
                c_offsets,
                inp_strides_0,
                inp_strides_1,
                out_strides_0,
                out_strides_1,
                index_cur,
                input_output_mask,
            )

292 

293 

def index_select(inp, dim, index):
    """Select entries of ``inp`` along ``dim`` at positions given by ``index``.

    Equivalent to ``torch.index_select``: the result has the same number of
    dimensions as ``inp`` with ``shape[dim]`` replaced by ``index.numel()``.

    Args:
        inp: input tensor (made contiguous internally).
        dim: dimension to select along; may be negative.
        index: 0-D or 1-D integer tensor of indices into ``inp.size(dim)``.

    Returns:
        A newly allocated tensor on ``inp``'s device with ``inp``'s dtype.
    """
    logger.debug("GEMS_CAMBRICON INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # TODO: index is on device, should it be a kernel (like cnnl __assert_fail__) to check this?
    # NOTE: the original generator-expression assert was vacuously true; this
    # real check forces a device sync, which the TODO above aims to remove.
    if index.numel() > 0:
        assert bool(
            ((index >= 0) & (index < inp.size(dim))).all()
        ), "Index out of range"

    # TODO: make sure input is contiguous

    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_dim = index.numel()

    # collapse the problem to 3D:
    # input  -> [batch_dim, select_dim, c_dim]
    # output -> [batch_dim, index_dim, c_dim]
    inp = inp.contiguous()
    index = index.contiguous()
    inp_numel = inp.numel()
    batch_dim = math.prod(inp_shape[:dim])
    select_dim = inp_shape[dim]
    c_dim = math.prod(inp_shape[(dim + 1) :])

    # copy before mutating so inp_shape keeps describing the input
    out_shape = list(inp_shape)
    out_shape[dim] = index_dim
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # element_size() covers float/int/bool/complex alike (the previous
    # finfo/iinfo branch raised on torch.bool)
    dtype_size = inp.element_size()

    if batch_dim == 1 and c_dim <= get_max_block_size(dtype_size):
        # ram: (input, output), half, extra
        # 2D fast path: a single logical batch and a full row fits in NRAM,
        # so c_dim is never split
        def grid_fn(meta):
            rows_per_block = max(meta["BLOCK_SIZE"] // c_dim, 1)
            num_blocks = triton.cdiv(index_dim, rows_per_block)
            return (min(num_blocks, TOTAL_CORE_NUM),)

        one_batch_index_select_kernel[grid_fn](
            out, index, index_dim, inp, c_dim, dtype_size, inp_numel
        )
    else:
        # one program per tile, capped at the number of cores
        grid = lambda meta: (
            min(
                triton.cdiv(batch_dim, meta["BLOCK_BATCH"])
                * triton.cdiv(index_dim, meta["BLOCK_INDEX"])
                * triton.cdiv(c_dim, meta["BLOCK_C"]),
                TOTAL_CORE_NUM,
            ),
        )
        multi_batch_index_select_kernel[grid](
            inp,
            index,
            out,
            batch_dim,
            select_dim,
            c_dim,
            index_dim,
            dtype_size,
            inp_numel,
        )
    return out