Coverage for src/flag_gems/ops/index_put.py: 91%

237 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the per-dim maximum shape over all tensor indices, right-aligned.

    None entries (basic-indexing markers) are ignored.  Shapes of different
    ranks are aligned on their trailing dims; each output dim is the max of
    the sizes present at that (right-aligned) position.  Returns [] when no
    tensor index is present.
    """
    shapes = [list(t.shape) for t in indices if t is not None]
    if not shapes:
        return []
    rank = max(len(s) for s in shapes)
    # Left-pad shorter shapes with 0 so trailing dims line up column-wise.
    padded = [[0] * (rank - len(s)) + s for s in shapes]
    return [max(column) for column in zip(*padded)]

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast every non-None entry of `indices` to `target_shape`, in place."""
    target = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if idx is None:
            continue
        if tuple(idx.shape) != target:
            indices[pos] = torch.broadcast_to(idx, target)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble shared by every generated index_put module."""
    for line in ("import triton", "import triton.language as tl"):
        code.writeline(line)
    code.newline()
    for line in (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(line)
    code.newline()
    code.newline()
    return code

49 

50 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit a Triton kernel named `kernel_name` into `code`.

    The generated kernel scatters `values` into `input` at positions selected
    by `indices_len` index tensors (each of rank `index_rank`, already
    broadcast to a common shape by the caller); `inp_rank` is the rank of the
    destination.  Grid axis 0 covers the M flattened index elements; grid
    axis 1 covers the N elements of the trailing, non-indexed input dims.
    """
    code.writeline("@libentry()")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel parameter list: data pointers, then shapes, strides,
        # problem sizes and compile-time flags.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr = 2,",
            "BLOCK_SIZE1: tl.constexpr = 2048,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there is no trailing slice to walk,
        # so axis 1 degenerates to a single-element block.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose the flat offset0 into per-dim coordinates of the index
        # shape.  indices0's shape stands in for all index tensors, which the
        # caller has already broadcast to a common shape.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose the flat offset1 into coordinates of the trailing
        # (non-indexed) input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are silently masked out (note: negative
        # indices are rejected rather than wrapped).
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: loaded index values select the leading dims,
        # offset1 coordinates select the trailing dims.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Source offset into `values`: index-space coordinates first, then the
        # trailing slice coordinates.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # Accumulation must use atomics: distinct index entries may collide
        # on the same destination element.
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

143 

144 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper `wrapper_name` that launches `kernel_name`.

    The wrapper unpacks shapes/strides from its tensor arguments, computes
    the 2D launch grid (M = flattened index count, N = volume of the
    trailing non-indexed dims) and forwards everything to the Triton kernel.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # M: total number of index entries (all index tensors share a shape);
        # N: product of the input dims that are not advanced-indexed.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must mirror the kernel signature emitted by
            # generate_index_put_kernel.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

193 

194 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate a complete module (imports + kernel + wrapper) for index_put.

    `inputs` is the (input, indices, values) tuple whose ranks determine the
    specialization; the generated source is appended to `code`.
    """
    inp_rank = inputs[0].ndim
    # Only real tensors participate; None entries mark basic-indexing slots.
    real_indices = [t for t in inputs[1] if t is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    indices_len = len(real_indices)
    index_rank = real_indices[0].ndim
    code = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code

214 

215 

class IndexPutFunction:
    """Generates, compiles and caches index_put overloads per rank signature.

    Each distinct (input rank, index count, index rank) triple gets its own
    generated source file and imported wrapper, cached in `self.overloads`.
    """

    def __init__(self):
        # Owning process id; generated-code caching is per-process.
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        overload = self.overloads.get(key)
        if overload is None:
            # Cache miss: generate the specialized source, write it to the
            # code cache and import the wrapper from the new module.
            buffer = IndentedBuffer()
            buffer = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                buffer,
            )
            file_path = code_cache_dir() / f"index_put_{key}.py"
            write_atomic(file_path, buffer.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            overload = module._index_put_wrapper
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from the ranks of input and index tensors."""
        inp, tensor_indices, _values = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        index_rank = tensor_indices[0].ndim if indices_len else 0
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

261 

262 

# Module-level singleton: one compiled-overload cache shared by
# index_put / index_put_.
_index_put_func = IndexPutFunction()

264 

265 

def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: returns a modified copy, leaving `inp` intact."""
    logger.debug("GEMS INDEX PUT")

    result = inp.clone()
    return index_put_(result, indices, values, accumulate)

271 

272 

def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: write (or accumulate) `values` into `inp` at the
    positions selected by `indices`.

    Args:
        inp: destination tensor, modified in place.
        indices: sequence of index tensors; `None` entries mark dims kept
            whole (basic indexing); bool/byte tensors are treated as masks.
        values: values to write; broadcast to the selected region's shape.
        accumulate: if True, add into existing elements instead of storing.

    Returns:
        `inp`, modified in place.

    Raises:
        ValueError: if no index (or no non-None index) is supplied.
        TypeError: if an index is neither None nor a tensor.
        IndexError: if more indices than input dims are supplied.
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)

    if not indices:
        raise ValueError("At least one index tensor is required")

    # Move index tensors onto the input's device so the kernel can read them.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]
    # step 1: index preprocessing
    processed_indices = []
    for idx in indices:
        if idx is None:
            processed_indices.append(None)
        elif not torch.is_tensor(idx):
            # Check tensor-ness before touching .dtype so non-tensors raise
            # the intended TypeError instead of an AttributeError.
            raise TypeError(
                "tensors used as indices must be long, int, byte or bool tensors"
            )
        elif idx.dtype in (torch.bool, torch.uint8):
            # Expand bool/byte masks into explicit integer indices (one index
            # tensor per mask dim), matching PyTorch advanced indexing.
            # Fix: this previously tested torch.int8, but PyTorch treats
            # uint8 ("byte") tensors as masks while int8 tensors are ordinary
            # integer indices.
            processed_indices.extend(idx.nonzero(as_tuple=True))
        else:
            processed_indices.append(idx)

    indices = processed_indices
    # Pad missing None indices to match input dimension
    if len(indices) < inp.ndim:
        indices.extend([None] * (inp.ndim - len(indices)))

    if len(indices) > inp.ndim:
        raise IndexError("too many indices for tensor of dimension {}".format(inp.ndim))

    # Step 2: Broadcast tensor indices to a common shape
    tensor_pos = [i for i, x in enumerate(indices) if x is not None]
    if not tensor_pos:
        raise ValueError("At least one non-None index tensor is required")

    tensor_indices = [indices[i] for i in tensor_pos]
    if len(tensor_indices) > 1:
        broadcasted = torch.broadcast_tensors(*tensor_indices)
        for i, pos in enumerate(tensor_pos):
            indices[pos] = broadcasted[i]

    # Step 3: Transpose — move advanced-indexed dims to the front when they
    # are non-contiguous or do not start at dim 0 (the kernel assumes the
    # indexed dims lead).
    is_contiguous = (tensor_pos[-1] - tensor_pos[0] + 1) == len(tensor_pos)
    starts_with_none = indices[0] is None
    need_transpose = not is_contiguous or starts_with_none

    if need_transpose:
        perm_order = tensor_pos + [i for i, x in enumerate(indices) if x is None]
        inp_view = inp.permute(perm_order)
        final_indices = [indices[i] for i in tensor_pos] + [None] * (
            len(indices) - len(tensor_pos)
        )
    else:
        inp_view = inp
        final_indices = indices

    # Step 4: Handle values shape and broadcasting.  The kernel expects the
    # (broadcast) index dims first, then the remaining sliced dims.
    tensors = [x for x in final_indices if x is not None]
    broadcast_shape = list(tensors[0].shape)
    slice_shape = [inp_view.shape[i] for i, x in enumerate(final_indices) if x is None]

    target_shape = broadcast_shape + slice_shape
    values = values.to(inp.device)
    if need_transpose and is_contiguous:
        num_before = tensor_pos[0]

        # 1. Broadcast to PyTorch's natural result shape: sliced dims before
        #    the index block, then the broadcast index dims, then the rest.
        before_dims = slice_shape[:num_before]
        after_dims = slice_shape[num_before:]
        natural_shape = before_dims + broadcast_shape + after_dims
        values = values.broadcast_to(natural_shape)

        # 2. Permute to the kernel's expectation (index dims first).
        B, T = len(before_dims), len(broadcast_shape)
        val_perm = (
            list(range(B, B + T)) + list(range(0, B)) + list(range(B + T, values.ndim))
        )
        values = values.permute(val_perm)
    else:
        # direct broadcast — layout already matches the kernel expectation
        values = values.broadcast_to(target_shape)

    _index_put_func(inp_view, tensors, values, accumulate)

    return inp