Coverage for src/flag_gems/runtime/backend/_metax/ops/index.py: 0%

281 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

11logger = logging.getLogger("flag_gems." + __name__) 

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Compute the right-aligned, element-wise max shape over index tensors.

    None entries (basic-indexing placeholders) are ignored. Axes are matched
    from the trailing dimension, broadcasting-style: for each axis offset the
    result holds the largest extent among the tensors that have that axis.
    Returns [] when there is no tensor index at all.
    """
    tensors = [t for t in indices if t is not None]
    if not tensors:
        return []
    max_rank = max(t.ndim for t in tensors)
    shape: List[int] = []
    # offset counts from the last axis; build the shape back-to-front.
    for offset in range(max_rank):
        extent = 0
        for t in tensors:
            if t.ndim > offset:
                extent = max(extent, t.shape[t.ndim - 1 - offset])
        shape.append(extent)
    shape.reverse()
    return shape

29 

30 

def broadcast_indices(indices, target_shape):
    """Expand every non-None entry of *indices* to *target_shape*, in place.

    Entries that already match the target shape, and None placeholders, are
    left untouched. The list is mutated; nothing is returned.
    """
    target = tuple(target_shape)
    for pos in range(len(indices)):
        idx = indices[pos]
        if idx is None:
            continue
        if tuple(idx.shape) != target:
            indices[pos] = torch.broadcast_to(idx, target)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble shared by every generated index module.

    Emits the triton imports, a blank line, the flag_gems helper imports,
    then two blank lines, and returns the same buffer for chaining.
    """
    for line in (
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(line)
    code.newline()
    for line in (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(line)

    code.newline()
    code.newline()
    return code

49 

50 

def generate_index_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton JIT kernel source for an advanced-index gather.

    The generated kernel is specialized for a fixed input rank
    (``inp_rank``), number of tensor indices (``indices_len``), and rank of
    the (broadcasted) index tensors (``index_rank``). Grid axis 0 walks the
    M flattened index positions; axis 1 walks the N elements of the
    trailing, non-indexed input dimensions.

    Returns the same ``code`` buffer for chaining.
    """
    # Decorators: cache entry + autotune over the configs registered
    # under the "index" key, then triton.jit.
    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        code.writeline('configs=runtime.get_tuned_config("index"),')
        code.writeline('key=["M", "N"],')
        code.writeline('strategy=["align32", "align32"],')
        code.writeline("warmup=5,")
        code.writeline("rep=10,")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Parameter list: pointers, then shapes, then strides, then sizes.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["out_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [f"out_stride{i}," for i in range(index_rank + inp_rank - indices_len)]
        args += [
            "M,",
            "N,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there is no trailing subspace, so
        # axis 1 degenerates to a single lane.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose offset0 into multi-dim index coordinates using the shape
        # of indices0 — this assumes all index tensors share one broadcast
        # shape (the caller broadcasts them before launch) — TODO confirm.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose offset1 into coordinates over the trailing (non-indexed)
        # input dimensions.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        # Load the gathered index value from each index tensor.
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # NOTE(review): out-of-range (including negative) index values are
        # masked out of the load/store rather than wrapped Python-style —
        # confirm that callers normalize negative indices upstream.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Flat input offset: gathered coordinates for the indexed dims plus
        # pass-through coordinates for the trailing dims.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Flat output offset: index-tensor coordinates first, then the
        # trailing input coordinates shifted past the index_rank dims.
        comp = [f"indices_idx{i} * out_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * out_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"out_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(input_ptr + input_offset , mask = mask)")
        code.writeline("tl.store(out_ptr + out_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

141 

142 

def generate_index_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that unpacks shapes/strides and launches the kernel.

    The generated ``wrapper_name(input, indices, out)`` function computes
    M (flattened index count) and N (volume of the trailing, non-indexed
    input dims), builds a 2D launch grid, and invokes ``kernel_name``.
    Returns the same ``code`` buffer for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, out):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("out_shape = out.shape")
        code.writeline("out_stride = out.stride()")
        # M: number of gathered positions; N: elements per gathered slice.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must mirror the kernel's parameter list emitted
            # by generate_index_kernel.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["out,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"out_stride[{i}]," for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,"]
            code.writelines(args)
        code.writeline(")")
        # NOTE(review): the generated wrapper returns `input`, not `out`;
        # the caller (IndexFunction.__call__ -> index) ignores the return
        # value and reads `out` directly, so this is currently harmless.
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

190 

191 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Assemble a complete generated module: imports, kernel, and wrapper.

    ``inputs`` is ``(input_tensor, indices)``; None entries inside
    ``indices`` are basic-indexing placeholders and do not count toward the
    kernel specialization. Raises ValueError when no tensor index remains.
    Returns the filled ``code`` buffer.
    """
    tensors = [idx for idx in inputs[1] if idx is not None]
    if not tensors:
        raise ValueError("At least one non-None index tensor is required")
    rank_of_input = inputs[0].ndim
    rank_of_index = tensors[0].ndim
    num_indices = len(tensors)
    code = generate_imports(code)
    generate_index_kernel(rank_of_input, num_indices, rank_of_index, kernel_name, code)
    generate_index_wrapper(
        rank_of_input, num_indices, rank_of_index, wrapper_name, kernel_name, code
    )
    return code

211 

212 

class IndexFunction:
    """Lazy factory and cache for rank-specialized index kernels.

    Each distinct (input rank, number of indices, index rank) combination
    gets its own generated source module, written to the on-disk code
    cache, imported dynamically, and memoized in ``self.overloads``.
    """

    def __init__(self):
        # Process id recorded at construction; not read within this class
        # (presumably informational / for per-process cache debugging —
        # TODO confirm).
        self.pid = os.getpid()
        # Maps arg_key string -> generated "_index_wrapper" callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch ``(inp, tensor_indices, out)`` to a cached or newly generated wrapper."""
        inp, tensor_indices, out = args
        full_args = (inp, tensor_indices)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Cache miss: generate the specialized module source.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_wrapper",
                "_index_jit_function",
                code,
            )

            # Persist to the shared code cache (atomic write avoids torn
            # files when multiple processes race on the same key).
            file_name = f"index_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # Import the freshly written module under a key-unique name.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from input rank, index count, and index rank."""
        inp, tensor_indices = args[0], args[1]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

259 

260 

# Module-level singleton: one kernel cache shared by all index() calls.
_index_func = IndexFunction()

262 

263 

def index(inp, indices):
    """Advanced indexing ``inp[indices]`` backed by a generated Triton kernel.

    Args:
        inp: tensor to index.
        indices: sequence of index tensors and/or None placeholders.
            Integer tensors (long/int) gather along their dimension;
            bool/byte/int8 tensors are treated as masks and expanded via
            ``nonzero``; None keeps the corresponding dimension intact.

    Returns:
        A new tensor holding the gathered values.

    Raises:
        ValueError: if ``indices`` is empty.
        IndexError: if there are more indices than input dimensions, or a
            mask's shape does not match the dimensions it spans.
        TypeError: if an index tensor has an unsupported dtype.
    """
    logger.debug("GEMS INDEX")
    original_indices = list(indices)  # Save original indices for later checks
    indices = list(indices)

    if not indices:
        raise ValueError("at least one index must be provided")

    # Move every index tensor onto the input's device.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Step 1: Process indices (convert mask dtypes to long, handle None),
    # following the PyTorch meta implementation. torch.uint8 ("byte") is
    # accepted as a mask dtype in addition to int8/bool: eager PyTorch
    # supports byte masks and the TypeError message below promises them.
    processed_indices = []
    for i, index in enumerate(indices):
        if index is not None:
            # Check dtype
            if index.dtype in [torch.int8, torch.uint8, torch.bool]:
                # Convert boolean/byte/int8 mask to long indices
                nonzero = index.nonzero()
                k = len(processed_indices)
                if k + index.ndim > inp.ndim:
                    raise IndexError(
                        f"too many indices for tensor of dimension {inp.ndim}"
                    )
                # The mask must match the input dims it covers.
                for j in range(index.ndim):
                    if index.shape[j] != inp.shape[k + j]:
                        raise IndexError(
                            f"The shape of the mask {index.shape} at index {i} "
                            f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                        )
                # One long-index column per mask dimension.
                for j in range(index.ndim):
                    processed_indices.append(nonzero.select(1, j))
            elif index.dtype in [torch.long, torch.int, torch.int32, torch.int64]:
                processed_indices.append(index)
            else:
                raise TypeError(
                    "tensors used as indices must be long, int, byte or bool tensors"
                )
        else:
            processed_indices.append(None)

    indices = processed_indices

    # Check indices count (after mask expansion, which may add entries).
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )

    # Save for later use
    has_any_tensor = any(idx is not None for idx in indices)
    starts_with_none = indices[0] is None if indices else False

    # Step 2: Broadcast tensor indices to a common shape (None untouched).
    tensor_indices = [idx for idx in indices if idx is not None]
    if tensor_indices:
        if len(tensor_indices) > 1:
            tensor_indices = list(torch.broadcast_tensors(*tensor_indices))
        # Write the broadcasted tensors back into their original slots.
        tensor_idx = 0
        for i in range(len(indices)):
            if indices[i] is not None:
                indices[i] = tensor_indices[tensor_idx]
                tensor_idx += 1

    # Step 3: Pad with None up to input.ndim (trailing dims pass through).
    while len(indices) < inp.ndim:
        indices.append(None)

    # Step 4: Detect whether all tensor indices are adjacent
    # (a "contiguous subspace" in PyTorch terminology). The for-else sets
    # the flag only when the loop never breaks on a second tensor run.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # Transpose tensor indices to the front when they are scattered, or when
    # the sequence starts with None (the kernel expects them leading).
    need_post_process = False
    first_tensor_dim = None
    if not has_contiguous_subspace or (starts_with_none and has_any_tensor):
        dims = []
        transposed_indices = []
        # First add all non-None index positions
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        # Then add all None positions
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        # Permute input to match the reordered indices.
        inp = inp.permute(dims)
        indices = transposed_indices

        # Post-processing is needed only when the original layout was
        # contiguous but started with None: the output dims must be permuted
        # back so the indexed dims land at first_tensor_dim.
        if starts_with_none and has_any_tensor and has_contiguous_subspace:
            need_post_process = True
            # Find first tensor dimension in the original indices.
            for i, idx in enumerate(original_indices):
                if idx is not None:
                    first_tensor_dim = i
                    break

    # Step 5: Output shape = before_shape + replacement_shape + after_shape.
    before_shape = []
    after_shape = []
    replacement_shape = []

    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                # None after tensor indices -> goes to after_shape
                after_shape.append(inp.shape[dim])
            else:
                # None before tensor indices -> goes to before_shape
                before_shape.append(inp.shape[dim])
        else:
            # All tensor indices share one (broadcasted) shape; the first
            # one seen fixes replacement_shape.
            if not replacement_shape:
                replacement_shape = list(index.shape)

    # Step 6: Build output shape and allocate the output tensor.
    out_shape = before_shape + replacement_shape + after_shape
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Step 7: Nothing to gather from an empty tensor.
    if inp.numel() == 0:
        return out

    # Step 8: All-None indexing is a pure reshape — no kernel needed.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        # reshape, not view: inp is not guaranteed contiguous here.
        return inp.reshape(*out_shape)

    # Step 9: Launch the generated kernel; it fills `out` in place.
    _index_func(inp, tensor_indices, out)

    # Step 10: Undo the leading-None transpose on the result if required.
    if need_post_process:
        index_rank = tensor_indices[0].ndim
        # Move the broadcast (indexed) dims after the leading None dims.
        pre_dims = list(range(index_rank, index_rank + first_tensor_dim))
        broadcast_dims = list(range(index_rank))
        post_dims = list(range(index_rank + first_tensor_dim, out.ndim))
        new_order = pre_dims + broadcast_dims + post_dims
        out = out.permute(new_order)

    return out