Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index.py: 0%

293 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned, elementwise-max shape over all tensor indices.

    ``None`` entries (basic-indexing markers) are ignored.  Axes are aligned
    from the last dimension, as in broadcasting: for each axis position the
    result holds the largest size any index tensor has there, and tensors of
    smaller rank simply do not contribute to the higher axes.  Returns an
    empty list when there is no tensor index at all.
    """
    tensors = [t for t in indices if t is not None]
    if not tensors:
        return []
    max_rank = max(t.ndim for t in tensors)
    # Collect sizes from the rightmost axis (offset 0) to the leftmost.
    reversed_shape = []
    for offset in range(max_rank):
        sizes = (
            t.shape[t.ndim - 1 - offset] for t in tensors if t.ndim - 1 - offset >= 0
        )
        reversed_shape.append(max(sizes, default=0))
    return list(reversed(reversed_shape))

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast, in place, every tensor entry of *indices* to *target_shape*.

    ``None`` entries and tensors already of the target shape are left alone.
    The list is mutated; nothing is returned.
    """
    target = tuple(target_shape)
    for pos in range(len(indices)):
        entry = indices[pos]
        if entry is None:
            continue
        if tuple(entry.shape) != target:
            indices[pos] = torch.broadcast_to(entry, target_shape)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble required by every generated index module.

    Emits the triton/builtins imports, a blank line, the flag_gems helper
    imports, then two blank lines.  Returns *code* for chaining.
    """
    triton_section = (
        "import triton",
        "import triton.language as tl",
        "import builtins",
    )
    flag_gems_section = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for stmt in triton_section:
        code.writeline(stmt)
    code.newline()
    for stmt in flag_gems_section:
        code.writeline(stmt)

    code.newline()
    code.newline()
    return code

50 

51 

def generate_index_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton JIT kernel source for advanced-index gather into *code*.

    The generated kernel runs on a 2D grid:
      * axis 0 walks the flattened broadcasted index space (``M`` elements,
        ``index_rank`` dims; per-dim sizes are taken from ``indices0``, so all
        index tensors are assumed pre-broadcast to one common shape — the
        caller (``index``) guarantees this),
      * axis 1 walks the trailing input dims not covered by index tensors
        (``N`` elements, input dims ``indices_len`` .. ``inp_rank - 1``).

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of (non-None) index tensors.
        index_rank: common rank of the broadcasted index tensors.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: IndentedBuffer that the source lines are appended to.

    Returns:
        The same *code* buffer, for chaining.
    """
    code.newline()
    code.newline()

    # Heuristics for block sizes: BLOCK_SIZE0 splits M into ~12 chunks,
    # BLOCK_SIZE1 covers N up to a cap of 4096; both at least 1.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline(
            'return builtins.max(1, triton.next_power_of_2(triton.cdiv(args["M"], 12)))'
        )

    code.newline()

    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline(
            'return builtins.max(1, builtins.min(triton.next_power_of_2(args["N"]), 4096))'
        )

    code.newline()
    code.newline()
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_SIZE0": heur_block_m,')
            code.writeline('"BLOCK_SIZE1": heur_block_n,')
        code.writeline("},")
    code.writeline(")")

    # NOTE: autotuning via libtuner is intentionally disabled here; the
    # heuristics above are used instead.
    # code.writeline("@libtuner(")
    # with code.indent():
    #     code.writeline('configs=runtime.get_tuned_config("index"),')
    #     code.writeline('key=["M", "N"],')
    #     code.writeline('strategy=["align32", "align32"],')
    #     code.writeline("warmup=5,")
    #     code.writeline("rep=10,")
    # code.writeline(")")

    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel signature: pointers, then per-dim shapes, then per-dim
        # strides, then out strides, then M/N and the constexpr block sizes.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["out_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [f"out_stride{i}," for i in range(index_rank + inp_rank - indices_len)]
        args += [
            "M,",
            "N,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # No trailing dims: N is 1, so axis 1 uses a single lane.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unflatten offset0 into per-dim coordinates of the broadcasted index
        # shape (taken from indices0), last dim fastest-varying.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unflatten offset1 into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        # Load the gather indices themselves (masked against M).
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked off rather than trapped.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Input offset: gathered coords for the indexed dims + trailing coords.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Output offset: index-space coords first, then the trailing dims.
        comp = [f"indices_idx{i} * out_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * out_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"out_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(input_ptr + input_offset , mask = mask)")
        code.writeline("tl.store(out_ptr + out_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

171 

172 

def generate_index_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python launch wrapper for the generated index kernel.

    The generated function has signature ``wrapper(input, indices, out)``:
    it unpacks shapes/strides into scalars, computes the 2D launch grid
    (M = number of gathered positions, N = volume of the trailing input
    dims not covered by index tensors), launches *kernel_name*, and
    returns ``out``.

    Fix: the wrapper previously ended with ``return input`` even though
    the kernel writes the gathered values into ``out``; it now returns
    the tensor that actually holds the result.

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of (non-None) index tensors.
        index_rank: common rank of the broadcasted index tensors.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the kernel the wrapper launches.
        code: IndentedBuffer the source lines are appended to.

    Returns:
        The same *code* buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, out):")
    with code.indent():
        # Shapes and strides are passed to the kernel as individual scalars.
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("out_shape = out.shape")
        code.writeline("out_stride = out.stride()")
        # All index tensors share one broadcasted shape, so indices[0]
        # determines M; N covers the input dims past the indexed ones.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must match the kernel signature emitted by
            # generate_index_kernel.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["out,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"out_stride[{i}]," for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,"]
            code.writelines(args)
        code.writeline(")")
        # The kernel gathers into `out`; return it (was `return input`).
        code.writeline("return out")
    code.newline()
    code.newline()
    return code

220 

221 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Assemble the complete generated module: imports, kernel, and wrapper.

    *inputs* is ``(input_tensor, indices_list)``; ``None`` entries in the
    indices list are ignored when computing ranks.  Raises ValueError when
    every index entry is ``None``.  Returns the filled *code* buffer.
    """
    input_tensor, raw_indices = inputs[0], inputs[1]
    tensor_indices = [idx for idx in raw_indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    inp_rank = input_tensor.ndim
    indices_len = len(tensor_indices)
    index_rank = tensor_indices[0].ndim

    code = generate_imports(code)
    generate_index_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code

241 

242 

class IndexFunction:
    """Callable that caches generated index kernels per rank combination.

    On first call for a given ``(inp_rank, indices_len, index_rank)``
    combination, the kernel + wrapper source is generated, written to the
    code cache directory, imported as a module, and the wrapper is memoized
    in ``self.overloads``; later calls with the same key reuse it.
    """

    def __init__(self):
        # Process id recorded at construction; not read elsewhere in this
        # file (presumably useful for cache diagnostics — unconfirmed).
        self.pid = os.getpid()
        # Maps arg_key(...) strings to the generated wrapper callables.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch ``(inp, tensor_indices, out)`` to a generated wrapper.

        *tensor_indices* must contain only non-None tensors; *out* receives
        the gathered values.  Returns whatever the generated wrapper returns.
        """
        inp, tensor_indices, out = args
        full_args = (inp, tensor_indices)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Generate the module source for this rank combination.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_wrapper",
                "_index_jit_function",
                code,
            )

            # Persist to the shared code cache (atomic write avoids partial
            # files when several processes race on the same key).
            file_name = f"index_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # Import the just-written file as a fresh module and grab the
            # wrapper entry point from it.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key string from input rank, index count, and index rank."""
        inp, tensor_indices = args[0], args[1]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

289 

290 

# Module-level singleton: caches one compiled kernel/wrapper per rank
# combination across all calls to index().
_index_func = IndexFunction()

292 

293 

def index(inp, indices):
    """Advanced indexing ``inp[indices]`` backed by a generated Triton kernel.

    Mirrors PyTorch's index preprocessing: boolean/int8 masks are converted
    to long coordinates via ``nonzero``, tensor indices are broadcast to a
    common shape, the indexed dims are made contiguous by permuting ``inp``
    when necessary, and the gather itself is delegated to ``_index_func``.

    Args:
        inp: tensor being indexed.
        indices: sequence of index tensors and/or ``None`` (basic-indexing
            markers), at most ``inp.ndim`` entries after mask expansion.

    Returns:
        A new tensor holding the gathered values.

    Raises:
        ValueError: if *indices* is empty.
        IndexError: too many indices, or a mask shape mismatch.
        TypeError: an index tensor has an unsupported dtype.
    """
    logger.debug("GEMS INDEX")
    original_indices = list(indices)  # Save original indices for later checks
    indices = list(indices)

    if not indices:
        raise ValueError("at least one index must be provided")

    # Move index tensors onto inp's device so the kernel can read them.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Step 1: Process indices (convert bool/int8 to long, handle None)
    # Following PyTorch meta implementation
    processed_indices = []
    for i, index in enumerate(indices):
        if index is not None:
            # Check dtype
            if index.dtype in [torch.int8, torch.bool]:
                # Convert boolean/int8 mask to long indices: each of the
                # mask's ndim dims expands into one coordinate index.
                nonzero = index.nonzero()
                k = len(processed_indices)
                if k + index.ndim > inp.ndim:
                    raise IndexError(
                        f"too many indices for tensor of dimension {inp.ndim}"
                    )
                # Check shape matches the dims of inp the mask covers.
                for j in range(index.ndim):
                    if index.shape[j] != inp.shape[k + j]:
                        raise IndexError(
                            f"The shape of the mask {index.shape} at index {i} "
                            f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                        )
                # Extract indices from nonzero: column j holds the
                # coordinates along the mask's j-th dim.
                for j in range(index.ndim):
                    processed_indices.append(nonzero.select(1, j))
            elif index.dtype in [torch.long, torch.int, torch.int32, torch.int64]:
                processed_indices.append(index)
            else:
                raise TypeError(
                    "tensors used as indices must be long, int, byte or bool tensors"
                )
        else:
            processed_indices.append(None)

    indices = processed_indices

    # Check indices count
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )

    # Save for later use (computed before padding/permuting below).
    has_any_tensor = any(idx is not None for idx in indices)
    starts_with_none = indices[0] is None if indices else False

    # Step 2: Broadcast indices (only tensor indices, not None)
    tensor_indices = [idx for idx in indices if idx is not None]
    if tensor_indices:
        # Broadcast all tensor indices together
        if len(tensor_indices) > 1:
            tensor_indices = list(torch.broadcast_tensors(*tensor_indices))
        # Update indices list with broadcasted tensors
        tensor_idx = 0
        for i in range(len(indices)):
            if indices[i] is not None:
                indices[i] = tensor_indices[tensor_idx]
                tensor_idx += 1

    # Step 3: Add missing None indices (pad to input.ndim)
    while len(indices) < inp.ndim:
        indices.append(None)

    # Step 4: Check if has contiguous subspace
    # (all non-None tensors are adjacent).  Small state machine:
    #   0 = before any tensor, 1 = inside the tensor run, 2 = after it.
    # A tensor seen in state 2 breaks out; the for/else sets the flag
    # only when the loop completes without a break.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # Transpose if not contiguous OR starts with None (and has tensor indices):
    # move all tensor-indexed dims to the front so the kernel sees them as a
    # leading contiguous block.
    need_post_process = False
    first_tensor_dim = None
    if not has_contiguous_subspace or (starts_with_none and has_any_tensor):
        dims = []
        transposed_indices = []
        # First add all non-None index positions
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        # Then add all None positions
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        # Permute input
        inp = inp.permute(dims)
        indices = transposed_indices

    # Check if we need post-processing
    # (only when originally started with None and was contiguous): the
    # result dims must be permuted back so the broadcast dims land where
    # PyTorch would place them (after the leading basic dims).
    if starts_with_none and has_any_tensor and has_contiguous_subspace:
        need_post_process = True
        # Find first tensor dimension in original indices
        for i, idx in enumerate(original_indices):
            if idx is not None:
                first_tensor_dim = i
                break

    # Step 5: Now indices have contiguous subspace (after potential transpose)
    # Calculate output shape: before_shape + replacement_shape + after_shape
    before_shape = []
    after_shape = []
    replacement_shape = []

    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                # None after tensor indices -> goes to after_shape
                after_shape.append(inp.shape[dim])
            else:
                # None before tensor indices -> goes to before_shape
                before_shape.append(inp.shape[dim])
        else:
            # First tensor index determines replacement_shape
            if not replacement_shape:
                replacement_shape = list(index.shape)

    # Step 6: Build output shape and create output tensor
    out_shape = before_shape + replacement_shape + after_shape
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Step 7: Handle empty tensor case
    if inp.numel() == 0:
        return out

    # Step 8: Extract only tensor indices for kernel
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        # All None, just reshape
        return inp.view(*out_shape)

    # Step 9: Call kernel with tensor indices (writes into `out`)
    _index_func(inp, tensor_indices, out)

    # Step 10: Post-process if needed (for originally contiguous tensor indices starting with None)
    if need_post_process:
        # Calculate index_rank from the first tensor index
        index_rank = tensor_indices[0].ndim
        # Create permutation order to move broadcast dimensions to correct position
        pre_dims = list(range(index_rank, index_rank + first_tensor_dim))
        broadcast_dims = list(range(index_rank))
        post_dims = list(range(index_rank + first_tensor_dim, out.ndim))
        new_order = pre_dims + broadcast_dims + post_dims
        out = out.permute(new_order)

    return out