Coverage for src/flag_gems/ops/index_put.py: 10%

232 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

# Module-scoped logger; named after this module so handlers can filter on it.
logger = logging.getLogger(name=__name__)

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the broadcast shape of all non-None index tensors.

    ``None`` entries (basic-indexing markers) are ignored. The result follows
    the standard broadcasting rules. A size-0 dimension broadcast against a
    size-1 dimension therefore yields 0; taking the per-dimension maximum
    would wrongly give 1. Incompatible shapes raise ``RuntimeError`` here
    instead of later inside ``torch.broadcast_to``.

    Args:
        indices: Sequence of index tensors, possibly containing ``None``.

    Returns:
        The broadcast shape as a list of ints, or ``[]`` when ``indices``
        holds no tensor.
    """
    shapes = [index.shape for index in indices if index is not None]
    if not shapes:
        return []
    return list(torch.broadcast_shapes(*shapes))

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast every non-None tensor in ``indices`` to ``target_shape`` in place.

    Entries that are ``None``, or whose shape already equals ``target_shape``,
    are left as the same object. The list passed in is mutated; nothing is
    returned.
    """
    wanted = tuple(target_shape)
    stale_positions = [
        pos
        for pos, tensor in enumerate(indices)
        if tensor is not None and tuple(tensor.shape) != wanted
    ]
    for pos in stale_positions:
        indices[pos] = torch.broadcast_to(indices[pos], target_shape)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of a generated index_put module to ``code``.

    Writes the Triton imports, a blank line, the flag_gems helper imports,
    then two blank lines. Returns the same buffer for chaining.
    """
    triton_lines = ("import triton", "import triton.language as tl")
    gems_lines = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in triton_lines:
        code.writeline(line)
    code.newline()
    for line in gems_lines:
        code.writeline(line)
    for _ in range(2):
        code.newline()
    return code

49 

50 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Append the Triton kernel that scatters ``values`` into ``input``.

    The generated kernel works on a 2-D program grid:

    * axis 0 walks flat offsets ``< M`` over the index tensors;
    * axis 1 walks flat offsets ``< N`` over input dims
      ``indices_len .. inp_rank - 1``.

    For each element it loads one index per indexed input dim. Negative
    indices are wrapped by adding the dim size, as PyTorch does. Indices that
    are still out of ``[0, input_shape)`` are masked out. It then either
    stores the matching ``values`` element or, when ``IS_ACCUMULATE`` is set,
    adds it with ``tl.atomic_add``.

    The flat index offset is unravelled with ``indices0``'s shape only, so
    all index tensors are expected to share that shape. ``index_rank == 0``
    would leave the index-load offset expression empty and is not supported
    by this generator.

    Args:
        inp_rank: Number of dims of ``input``.
        indices_len: Number of index tensors; they address input dims
            ``0 .. indices_len - 1``.
        index_rank: Number of dims of each index tensor.
        kernel_name: Name of the generated ``@triton.jit`` function.
        code: Buffer the kernel source is appended to.

    Returns:
        ``code``, for chaining.
    """
    code.writeline("@libentry()")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        # values strides: index dims first, then the non-indexed input dims.
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr = 2,",
            "BLOCK_SIZE1: tl.constexpr = 2048,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # Every input dim is indexed: axis 1 needs just a single column.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel the flat index offset (innermost dim first) with indices0's shape.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel the flat offset over the non-indexed trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
            # Negative indices count from the end of the dim (PyTorch semantics);
            # wrap them before the bounds check so they are not silently dropped.
            code.writeline(
                f"cur_index{i} = tl.where(cur_index{i} < 0, cur_index{i} + input_shape{i}, cur_index{i})"
            )
        code.newline()
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

143 

144 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Append the Python wrapper that launches the generated kernel.

    The emitted function ``{wrapper_name}(input, indices, values, accumulate)``
    does the following:

    * reads the shapes and strides of its tensors;
    * sets ``M`` to ``indices[0].numel()`` and ``N`` to the volume of
      ``input`` dims ``indices_len..``;
    * launches ``{kernel_name}`` on grid
      ``(cdiv(M, BLOCK_SIZE0), cdiv(N, BLOCK_SIZE1))``, passing
      ``accumulate==True`` after ``M`` and ``N``;
    * returns ``input``.

    Returns ``code`` for chaining.
    """
    values_rank = index_rank + inp_rank - indices_len
    index_ids = range(indices_len)

    prologue = ["input_shape = input.shape", "input_stride = input.stride()"]
    for k in index_ids:
        prologue.append(f"indices{k}_shape = indices[{k}].shape")
        prologue.append(f"indices{k}_stride = indices[{k}].stride()")
    prologue.extend(
        [
            "values_shape = values.shape",
            "values_stride = values.stride()",
            "M = indices[0].numel()",
            f"N = volume(input_shape[{indices_len}: ])",
        ]
    )

    launch_args = (
        ["input,"]
        + [f"indices[{k}]," for k in index_ids]
        + ["values,"]
        + [f"input_shape[{d}]," for d in range(inp_rank)]
        + [f"indices{k}_shape[{d}]," for k in index_ids for d in range(index_rank)]
        + [f"input_stride[{d}]," for d in range(inp_rank)]
        + [f"indices{k}_stride[{d}]," for k in index_ids for d in range(index_rank)]
        + [f"values_stride[{d}]," for d in range(values_rank)]
        + ["M,", "N,", "accumulate==True,"]
    )

    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        for statement in prologue:
            code.writeline(statement)
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writelines(launch_args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

193 

194 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate a complete index_put module (imports, kernel, wrapper) into ``code``.

    ``inputs`` is ``(input, indices, values)``. ``None`` entries in
    ``indices`` are skipped. The number of remaining index tensors and the
    rank of the first one specialize the generated code.

    Raises:
        ValueError: if ``indices`` contains no tensor.
    """
    inp_rank = inputs[0].ndim
    present = [index for index in inputs[1] if index is not None]
    if not present:
        raise ValueError("At least one non-None index tensor is required")
    indices_len, index_rank = len(present), present[0].ndim

    code = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code

214 

215 

216class IndexPutFunction: 

217 def __init__(self): 

218 self.pid = os.getpid() 

219 self.overloads: Mapping[str, Callable] = {} 

220 

221 def __call__(self, *args, **kwargs): 

222 inp, tensor_indices, values, accumulate = args 

223 full_args = (inp, tensor_indices, values) 

224 

225 key = self.arg_key(*full_args) 

226 if key in self.overloads: 

227 overload = self.overloads[key] 

228 else: 

229 code = IndentedBuffer() 

230 code = generate_code( 

231 full_args, 

232 "_index_put_wrapper", 

233 "_index_put_jit_function", 

234 code, 

235 ) 

236 file_name = f"index_put_{key}.py" 

237 file_path = code_cache_dir() / file_name 

238 write_atomic(file_path, code.getvalue()) 

239 

240 spec = importlib.util.spec_from_file_location( 

241 f"_gen_module_rank_{key}", 

242 file_path, 

243 ) 

244 

245 m = importlib.util.module_from_spec(spec) 

246 spec.loader.exec_module(m) 

247 overload = getattr(m, "_index_put_wrapper") 

248 self.overloads[key] = overload 

249 

250 return overload(*args) 

251 

252 def arg_key(self, *args, **kwargs): 

253 inp, tensor_indices, _ = args[0], args[1], args[2] 

254 inp_rank = inp.ndim 

255 indices_len = len(tensor_indices) 

256 if indices_len == 0: 

257 index_rank = 0 

258 else: 

259 index_rank = tensor_indices[0].ndim 

260 return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}" 

261 

262 

263_index_put_func = IndexPutFunction() 

264 

265 

def _prepare_index_put(inp, indices, values):
    """Normalize ``index_put`` arguments into the form the generated kernel expects.

    The kernel only handles integer index tensors that address the *leading*
    dims of its input and that all share one shape. This helper does the
    following, in order:

    * Moves indices and ``values`` to ``inp``'s device.
    * Expands each boolean mask into the integer coordinates of its ``True``
      entries (``torch.where``). A k-dim mask therefore indexes k
      consecutive dims, as in PyTorch.
    * Keeps the historical leniency for a *single* boolean mask:
      - a one-element ``values`` is filled to the target shape;
      - a ``K``-element ``values`` (``K`` = number of selected entries) is
        flattened to ``(K,)`` when the mask covers every input dim.
      When the input has trailing dims, ``values`` is left to regular
      broadcasting instead, because flattening it would misalign it
      against those dims.
    * Pads ``indices`` with ``None`` (full slices) up to ``inp.ndim``.
    * Broadcasts the index tensors to one shape, and ``values`` to the
      advanced-indexing result shape:
      - adjacent indexed dims are replaced in place by the index shape;
      - separated indexed dims put the index shape first.
    * Reorders ``values`` so the index dims come first, matching an input
      whose indexed dims are permuted to the front.

    Returns:
        ``(perm, tensor_indices, values)``. Apply ``perm`` to the tensor
        being written via ``_permute_for_kernel``, then pass it with the
        other two values to the kernel.

    Raises:
        ValueError: ``indices`` holds no tensor.
        IndexError: more indices than ``inp`` has dims.
    """
    indices = list(indices)
    single_mask = (
        len(indices) == 1
        and indices[0] is not None
        and indices[0].dtype == torch.bool
    )

    expanded = []
    for index in indices:
        if index is None:
            expanded.append(None)
            continue
        if index.device != inp.device:
            index = index.to(inp.device)
        if index.dtype == torch.bool:
            expanded.extend(torch.where(index))
        else:
            expanded.append(index)
    indices = expanded

    if single_mask:
        k = indices[0].numel()
        trailing = tuple(inp.shape[len(indices) :])
        if values.numel() == 1:
            values = torch.full(
                (k,) + trailing, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == k and not trailing:
            values = values.reshape((k,))

    if values.device != inp.device:
        values = values.to(inp.device)

    indices += [None] * (inp.ndim - len(indices))
    indexed_dims = [dim for dim, index in enumerate(indices) if index is not None]
    if not indexed_dims:
        raise ValueError("At least one non-None index tensor is required")
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )

    tensor_indices = [indices[dim] for dim in indexed_dims]
    index_shape = list(get_max_rank_shape(tensor_indices))
    broadcast_indices(tensor_indices, index_shape)

    other_dims = [dim for dim in range(inp.ndim) if indices[dim] is None]
    first, last = indexed_dims[0], indexed_dims[-1]
    if last - first + 1 == len(indexed_dims):
        # Adjacent indexed dims: the broadcast index dims replace them in place.
        result_shape = (
            list(inp.shape[:first]) + index_shape + list(inp.shape[last + 1 :])
        )
    else:
        # Separated indexed dims: the broadcast index dims go first.
        first = 0
        result_shape = index_shape + [inp.shape[dim] for dim in other_dims]
    values = torch.broadcast_to(values, result_shape)

    if first:
        # Move the index dims to the front; the remaining dims keep their
        # order, which is exactly ``other_dims``.
        rank = len(index_shape)
        order = (
            list(range(first, first + rank))
            + list(range(first))
            + list(range(first + rank, len(result_shape)))
        )
        values = values.permute(order)

    perm = indexed_dims + other_dims
    return perm, tensor_indices, values


def _permute_for_kernel(tensor, perm):
    """Return a view of ``tensor`` with dims ordered by ``perm`` (``tensor`` itself if identity)."""
    if perm == list(range(tensor.ndim)):
        return tensor
    return tensor.permute(perm)


def index_put(inp, indices, values, accumulate=False):
    """Return a copy of ``inp`` with ``values`` put at ``indices``.

    Args:
        inp: Tensor to copy and write into.
        indices: Sequence of index tensors, boolean masks, or ``None``
            (full slice), one entry per leading dim of ``inp``.
        values: Tensor broadcastable to the indexing result shape.
        accumulate: If true, add ``values`` instead of overwriting.

    Returns:
        A new tensor; ``inp`` is not modified.
    """
    logger.debug("GEMS INDEX PUT")

    perm, tensor_indices, values = _prepare_index_put(inp, indices, values)
    out = inp.clone()
    # The permuted view shares storage with ``out``, so writes land in ``out``.
    _index_put_func(_permute_for_kernel(out, perm), tensor_indices, values, accumulate)
    return out


def index_put_(inp, indices, values, accumulate=False):
    """In-place variant of ``index_put``: write ``values`` into ``inp`` and return it."""
    logger.debug("GEMS INDEX PUT_")

    perm, tensor_indices, values = _prepare_index_put(inp, indices, values)
    # The permuted view shares storage with ``inp``, so writes land in ``inp``.
    _index_put_func(_permute_for_kernel(inp, perm), tensor_indices, values, accumulate)
    return inp