Coverage for src/flag_gems/runtime/backend/_metax/ops/index_put.py: 0%

233 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

11logger = logging.getLogger("flag_gems." + __name__) 

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Compute the right-aligned elementwise-max shape of the index tensors.

    `None` entries (basic-indexing markers) are ignored.  The result is the
    shape every index tensor can be broadcast to: for each axis position,
    counted from the trailing axis, take the largest extent among the
    tensors that have that axis.  Returns [] when there is no tensor index.
    """
    tensors = [t for t in indices if t is not None]
    if not tensors:
        return []
    rank = max(t.ndim for t in tensors)
    result = []
    # pos counts axes from the trailing end so shorter tensors align right.
    for pos in range(rank):
        dims = [t.shape[t.ndim - 1 - pos] for t in tensors if t.ndim > pos]
        result.append(max(dims, default=0))
    result.reverse()
    return result

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast, in place, every tensor entry of `indices` to `target_shape`.

    `None` entries are left untouched; tensors already at the target shape
    are kept as-is (no copy is made for them).
    """
    wanted = tuple(target_shape)
    for pos in range(len(indices)):
        tensor = indices[pos]
        if tensor is None:
            continue
        if tuple(tensor.shape) != wanted:
            indices[pos] = torch.broadcast_to(tensor, wanted)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import header shared by every generated index_put module."""
    for line in ("import triton", "import triton.language as tl"):
        code.writeline(line)
    code.newline()

    flag_gems_lines = (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in flag_gems_lines:
        code.writeline(line)

    code.newline()
    code.newline()
    return code

49 

50 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton JIT kernel source for one index_put specialization.

    The kernel is specialized on the input rank (`inp_rank`), the number of
    index tensors (`indices_len`) and the common rank of the already-broadcast
    index tensors (`index_rank`).  Grid axis 0 walks the M flattened index
    elements; grid axis 1 walks the N elements of the trailing input dims not
    covered by the indices.  The generated text is appended to `code`, which
    is also returned.
    """
    code.writeline("@libentry()")
    code.writeline(
        '@triton.autotune(configs=runtime.get_tuned_config("index_put"), key=["M", "N"], restore_value=["input_ptr"])'
    )
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Signature layout: pointers, then shapes, then strides, then sizes
        # and constexpr flags.  This order must match the generated wrapper.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When the indices cover every input dim there are no trailing dims,
        # so axis 1 degenerates to a single lane.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unflatten offset0 into per-axis coordinates of the index tensors.
        # All index tensors share indices0's shape after broadcasting, so
        # indices0_shape* is used for every one of them.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unflatten offset1 into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # NOTE(review): out-of-range index values are silently masked out
        # rather than raising — confirm this matches intended semantics
        # (negative indices are NOT wrapped here).
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: loaded index values for the leading dims plus
        # the unflattened coordinates for the trailing dims.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Source offset into values, which has shape
        # index_shape + trailing input dims.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # atomic_add for accumulate=True so duplicate indices sum correctly.
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

146 

147 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that unpacks shapes/strides and launches the
    generated kernel.

    The wrapper signature is `(input, indices, values, accumulate)`; its
    argument order at the kernel launch site must mirror the parameter
    order emitted by `generate_index_put_kernel`.  Appends to `code` and
    returns it.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # M: number of index combinations (all indices share one numel after
        # broadcasting); N: volume of the trailing input dims not indexed.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
            code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            # `accumulate==True` pins a possibly non-bool truthy argument to
            # a plain Python bool for the IS_ACCUMULATE constexpr.
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

196 

197 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate the full module source (imports + kernel + wrapper) for the
    index_put specialization implied by `inputs` = (input, indices, values).

    Raises ValueError when `inputs[1]` contains no tensor index.
    """
    input_tensor, raw_indices = inputs[0], inputs[1]
    # None entries are basic-indexing markers; only tensors drive the codegen.
    real_indices = [idx for idx in raw_indices if idx is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")

    rank_of_input = input_tensor.ndim
    num_indices = len(real_indices)
    rank_of_index = real_indices[0].ndim

    code = generate_imports(code)
    generate_index_put_kernel(
        rank_of_input, num_indices, rank_of_index, kernel_name, code
    )
    generate_index_put_wrapper(
        rank_of_input, num_indices, rank_of_index, wrapper_name, kernel_name, code
    )
    return code

217 

218 

class IndexPutFunction:
    """Dispatcher that JIT-generates, caches and invokes an index_put
    overload per (input rank, indices length, index rank) signature."""

    def __init__(self):
        # Process id recorded at construction time; generated modules are
        # written per-process into the code cache.
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        cache_key = self.arg_key(*full_args)
        overload = self.overloads.get(cache_key)
        if overload is None:
            overload = self._build_overload(cache_key, full_args)
            self.overloads[cache_key] = overload
        return overload(*args)

    def _build_overload(self, cache_key, full_args):
        # Generate the specialized source, persist it in the code cache,
        # import it as a module and hand back its wrapper function.
        buffer = generate_code(
            full_args,
            "_index_put_wrapper",
            "_index_put_jit_function",
            IndentedBuffer(),
        )
        file_path = code_cache_dir() / f"index_put_{cache_key}.py"
        write_atomic(file_path, buffer.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{cache_key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module._index_put_wrapper

    def arg_key(self, *args, **kwargs):
        """Build the cache key from the ranks of input/indices only."""
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        index_rank = tensor_indices[0].ndim if indices_len else 0
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

264 

265 

266_index_put_func = IndexPutFunction() 

267 

268 

def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of `inp` with `values` written
    (or accumulated when `accumulate=True`) at the positions selected by
    `indices`.

    Args:
        inp: source tensor; never modified.
        indices: sequence of index tensors (or `None` markers).  A single
            bool tensor is treated as a mask and converted via torch.where.
        values: values to write; broadcast to the target shape.
        accumulate: add instead of overwrite when True.

    Returns:
        A new tensor with the updates applied.

    Raises:
        ValueError: if `indices` contains no tensor (all entries None).
    """
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)
    # Boolean-mask fast path: convert the mask to integer coordinates.
    # Guard the None case before touching .dtype (previously AttributeError).
    if len(indices) == 1 and indices[0] is not None and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # One value per selected index: add singleton trailing dims so
            # expand() can broadcast across inp's remaining axes.  The old
            # reshape((K,)) failed whenever trailing dims existed, because
            # expand aligns shapes from the trailing dimension.
            values = values.reshape((K,) + (1,) * (len(target_shape) - 1)).expand(
                target_shape
            )

    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out

314 

315 

def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: write (or accumulate when `accumulate=True`)
    `values` into `inp` at the positions selected by `indices`.

    Args:
        inp: tensor modified in place and returned.
        indices: sequence of index tensors (or `None` markers).  A single
            bool tensor is treated as a mask and converted via torch.where.
        values: values to write; broadcast to the target shape.
        accumulate: add instead of overwrite when True.

    Returns:
        `inp` (modified in place).

    Raises:
        ValueError: if `indices` contains no tensor (all entries None).
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)
    # Boolean-mask fast path: convert the mask to integer coordinates.
    # Guard the None case before touching .dtype (previously AttributeError).
    if len(indices) == 1 and indices[0] is not None and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # One value per selected index: add singleton trailing dims so
            # expand() can broadcast across inp's remaining axes.  The old
            # reshape((K,)) failed whenever trailing dims existed, because
            # expand aligns shapes from the trailing dimension.
            values = values.reshape((K,) + (1,) * (len(target_shape) - 1)).expand(
                target_shape
            )

    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp