Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_put.py: 0%

244 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

# Per-op logger named after this module file (".../ops.index_put") so log
# records can be filtered by operator.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

14 

15 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned, per-dimension maximum of the index shapes.

    ``None`` entries (basic-indexing markers) are ignored.  The result has
    the rank of the deepest index tensor; each position holds the largest
    size observed among the tensors at that (right-aligned) dimension.
    Returns ``[]`` when no tensor index is present.
    """
    tensors = [t for t in indices if t is not None]
    if not tensors:
        return []
    rank = max(t.ndim for t in tensors)
    out = []
    # Walk dimensions from the innermost (rightmost) one outward.
    for offset in range(rank):
        sizes = [t.shape[t.ndim - 1 - offset] for t in tensors if t.ndim > offset]
        out.append(max(sizes, default=0))
    out.reverse()
    return out

31 

32 

def broadcast_indices(indices, target_shape):
    """Broadcast every tensor in *indices* to *target_shape*, in place.

    ``None`` entries are left untouched, and tensors whose shape already
    matches the target are kept as the same objects (no copy, no new view).
    """
    wanted = tuple(target_shape)
    for pos in range(len(indices)):
        tensor = indices[pos]
        if tensor is None:
            continue
        if tuple(tensor.shape) != wanted:
            indices[pos] = torch.broadcast_to(tensor, target_shape)

37 

38 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header shared by every generated index_put module.

    Returns the same buffer to allow call chaining.
    """
    # Triton itself first, then the flag_gems runtime helpers.
    for line in ("import triton", "import triton.language as tl"):
        code.writeline(line)
    code.newline()
    runtime_imports = (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in runtime_imports:
        code.writeline(line)

    code.newline()
    code.newline()
    return code

51 

52 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton kernel source for index_put into *code*.

    The kernel text is specialized on the input rank (``inp_rank``), the
    number of index tensors (``indices_len``) and their common broadcast
    rank (``index_rank``); it is emitted under the name ``kernel_name``.
    Grid axis 0 covers the flattened index space (M elements), axis 1 the
    trailing, non-indexed input dimensions (N elements).  Returns *code*.
    """
    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        code.writeline('configs=runtime.get_tuned_config("index_put"),')
        code.writeline('key=["M", "N"],')
        code.writeline('restore_value=["input_ptr"],')
        code.writeline('strategy=["align32", "align32"],')
        code.writeline("warmup=5,")
        code.writeline("rep=10,")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel signature: pointers, then shapes, then strides, then the
        # flattened sizes M/N and the constexpr tuning knobs.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        # values has index_rank leading dims plus the input dims not
        # consumed by the index tensors.
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # Every input dim is indexed, so axis 1 degenerates to width 1.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel offset0 into per-dim coordinates of the index tensors.
        # indices0_shape* is used for all of them: the callers broadcast
        # every index tensor to one common shape beforehand.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel offset1 into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        # Gather the actual index values (masked beyond M).
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are silently dropped via the mask.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Linear offset into the input: indexed dims use the loaded index
        # values, trailing dims use the unraveled offset1 coordinates.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Linear offset into values: index-space coords first, then the
        # trailing input coords shifted past the index_rank leading dims.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # accumulate=True uses atomics so duplicate indices sum correctly.
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

154 

155 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that launches the generated kernel.

    The wrapper, emitted under ``wrapper_name``, unpacks shapes/strides,
    computes M (elements per index tensor) and N (product of the trailing,
    non-indexed input dims), builds the 2D launch grid, calls
    ``kernel_name`` and returns the (mutated) input tensor.  Returns *code*.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # All index tensors share one broadcast shape, so indices[0] is
        # representative for the flattened index count M.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must mirror the kernel signature emitted by
            # generate_index_put_kernel exactly.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

204 

205 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate a full index_put module: imports, Triton kernel, wrapper.

    *inputs* is ``(input_tensor, indices, values)``; the generated code is
    specialized on the input rank, the number of non-None index tensors and
    their rank.  Raises ValueError when no tensor index is present.
    Returns the filled *code* buffer.
    """
    input_tensor, raw_indices = inputs[0], inputs[1]
    # None entries are basic-indexing placeholders; only real tensors
    # determine the kernel specialization.
    real_indices = [ind for ind in raw_indices if ind is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    rank_of_input = input_tensor.ndim
    num_indices = len(real_indices)
    rank_of_index = real_indices[0].ndim
    code = generate_imports(code)
    generate_index_put_kernel(
        rank_of_input, num_indices, rank_of_index, kernel_name, code
    )
    generate_index_put_wrapper(
        rank_of_input, num_indices, rank_of_index, wrapper_name, kernel_name, code
    )
    return code

225 

226 

class IndexPutFunction:
    """Caching dispatcher for generated index_put overloads.

    Each distinct (input rank, #index tensors, index rank) combination gets
    its own generated module, written to the code cache directory, imported
    once and memoized in ``self.overloads``.
    """

    def __init__(self):
        # PID recorded at construction; not used elsewhere in this file —
        # presumably for fork/cache-invalidation bookkeeping (TODO confirm).
        self.pid = os.getpid()
        # NOTE(review): annotated Mapping but mutated below — effectively a
        # plain dict cache keyed by arg_key() strings.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch (inp, tensor_indices, values, accumulate) to the cached
        or freshly generated overload and return its result."""
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Cache miss: generate the specialized module source...
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                code,
            )
            # ...persist it atomically to the shared code cache...
            file_name = f"index_put_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # ...and import it as a standalone module.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_put_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from input rank, index count and index rank."""
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

272 

273 

274_index_put_func = IndexPutFunction() 

275 

276 

def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of *inp* with *values* written
    (or accumulated, when ``accumulate=True``) at *indices*.

    *indices* follows the aten convention: a sequence of index tensors
    (``None`` entries mark non-advanced dimensions), or a single boolean
    mask.  Raises ValueError when no tensor index is present.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT")

    indices = list(indices)
    # Boolean-mask fast path: convert the mask to coordinate tensors.
    # Guard `indices[0] is not None` first — a legal `[None]` index list
    # would otherwise raise AttributeError on `.dtype` instead of the
    # intended ValueError below.
    if (
        len(indices) == 1
        and indices[0] is not None
        and indices[0].dtype == torch.bool
    ):
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        # K selected positions; values must cover K plus any trailing
        # (non-indexed) input dims.
        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar fill: materialize to the full target shape.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors that live on another device onto inp's device.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices).
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    # Clone so the caller's tensor is untouched, then run the generated
    # kernel in place on the copy.
    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out

322 

323 

def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put_: write *values* into *inp* (or accumulate, when
    ``accumulate=True``) at *indices* and return *inp*.

    *indices* follows the aten convention: a sequence of index tensors
    (``None`` entries mark non-advanced dimensions), or a single boolean
    mask.  Raises ValueError when no tensor index is present.  Falls back
    to the aten implementation when *values* cannot be broadcast.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT_")

    indices = list(indices)
    # Boolean-mask fast path: convert the mask to coordinate tensors.
    # Guard `indices[0] is not None` first — a legal `[None]` index list
    # would otherwise raise AttributeError on `.dtype` instead of the
    # intended ValueError below.
    if (
        len(indices) == 1
        and indices[0] is not None
        and indices[0].dtype == torch.bool
    ):
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        # K selected positions; values must cover K plus any trailing
        # (non-indexed) input dims.
        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar fill: materialize to the full target shape.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors that live on another device onto inp's device.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices).
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    try:
        values = torch.broadcast_to(values, target_shape)
    except Exception:
        # Shapes this path cannot express (e.g. values not broadcastable to
        # the computed target) are redispatched to the stock aten kernel.
        return torch.ops.aten.index_put_.default.redispatch(
            torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd),
            inp,
            tensor_indices,
            values,
            accumulate,
        )

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp