Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index_put.py: 0%

252 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned elementwise-max shape over all index tensors.

    ``None`` entries (basic-indexing markers) are ignored. Shapes are aligned
    at their trailing dimension, missing leading dimensions count as 0, and
    each output dimension is the maximum across the tensors. Returns [] when
    no real tensor index is present.
    """
    tensors = [idx for idx in indices if idx is not None]
    if not tensors:
        return []
    max_rank = max(t.ndim for t in tensors)
    # Build the shape from the rightmost axis outward, then flip.
    reversed_shape = []
    for pos in range(max_rank):
        largest = 0
        for t in tensors:
            axis = t.ndim - 1 - pos
            if axis >= 0:
                largest = max(largest, t.shape[axis])
        reversed_shape.append(largest)
    reversed_shape.reverse()
    return reversed_shape

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast every non-None index tensor in *indices* to *target_shape*.

    Mutates the list in place; entries already at the target shape (and
    ``None`` markers) are left untouched. Returns ``None``.
    """
    target = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if idx is None:
            continue
        if tuple(idx.shape) != target:
            indices[pos] = torch.broadcast_to(idx, target)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the header of the generated module: imports plus the two
    block-size heuristic functions used by ``triton.heuristics``.
    """
    code.writelines(
        [
            "import triton",
            "import triton.language as tl",
            "import builtins",
        ]
    )
    code.newline()
    code.writelines(
        [
            "from flag_gems.utils import libentry",
            "from flag_gems import runtime",
            "from flag_gems.utils.shape_utils import volume",
        ]
    )

    code.newline()
    code.newline()

    # BLOCK_SIZE0 heuristic: split M roughly twelve ways, power-of-two sized.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('if args["M"] == 0:')
        with code.indent():
            code.writeline("return 2")
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    # BLOCK_SIZE1 heuristic: cover N, capped at 8192.
    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 8192)')

    code.newline()
    code.newline()
    return code

65 

66 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit a Triton kernel named *kernel_name* into *code*.

    The generated kernel scatters a values tensor into an input tensor at
    positions given by ``indices_len`` index tensors (each of rank
    ``index_rank``); the input has rank ``inp_rank``. Grid axis 0 walks the
    flattened index space (M elements), grid axis 1 walks the trailing
    non-indexed input dimensions (N elements). When IS_ACCUMULATE is set the
    kernel uses atomic adds, otherwise plain stores. Returns *code*.
    """
    code.writeline("@libentry()")
    # Autotune decorator kept for reference; heuristics are used instead.
    # code.writeline(
    #     '@triton.autotune(configs=runtime.get_tuned_config("index_put"), key=["M", "N"], restore_value=["input_ptr"])'
    # )
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_SIZE0": heur_block_m,')
            code.writeline('"BLOCK_SIZE1": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel signature: pointers first, then every shape/stride as a
        # constexpr, then the launch sizes and flags.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}: tl.constexpr," for j in range(index_rank)]
        args += [f"input_stride{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}: tl.constexpr," for j in range(index_rank)]
        # values has one stride per index dim plus one per trailing input dim.
        args += [
            f"values_stride{i}: tl.constexpr,"
            for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M: tl.constexpr,",
            "N: tl.constexpr,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tl.program_id(axis=0)")
        code.writeline("pid1 = tl.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # All dims are indexed: the second axis degenerates to size 1.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel offset0 into multi-dim coordinates of the (broadcast) index
        # tensors. All index tensors share one shape, so indices0's is used.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel offset1 into coordinates of the trailing (non-indexed)
        # input dimensions.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked out rather than trapping.
        # NOTE(review): negative indices are dropped, not wrapped -- differs
        # from torch's Python-style negative indexing; confirm intended.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: gathered indices for the leading dims, unravelled
        # coordinates for the trailing dims.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Source offset into values follows the broadcast index shape, then
        # the trailing input dims.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

171 

172 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python launch wrapper *wrapper_name* for *kernel_name*.

    The generated function unpacks shapes/strides from the runtime tensors,
    computes the launch sizes M (flattened index count) and N (product of the
    trailing non-indexed input dims), and launches the kernel on a 2D grid.
    Returns *code*.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # All index tensors were broadcast to a common shape upstream, so
        # indices[0].numel() is the full index count.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        # Argument order must mirror the kernel signature emitted by
        # generate_index_put_kernel exactly.
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

221 

222 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Fill *code* with a complete generated module: header imports, the
    Triton kernel, and its Python launch wrapper.

    *inputs* is ``(input_tensor, index_list, values)``; ``None`` entries in
    the index list are basic-indexing markers and do not count toward the
    kernel signature. Raises ValueError when no real index tensor remains.
    """
    inp_rank = inputs[0].ndim
    # Only actual tensors participate in the generated kernel's signature.
    tensor_indices = [idx for idx in inputs[1] if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")
    indices_len = len(tensor_indices)
    index_rank = tensor_indices[0].ndim

    code = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code

242 

243 

class IndexPutFunction:
    """Dispatcher that JIT-generates and caches index_put kernel wrappers.

    Each distinct (input rank, index count, index rank) combination gets its
    own generated Triton module, written to the code cache directory,
    imported once, and memoized in ``self.overloads``.
    """

    def __init__(self):
        # PID recorded at construction; not read anywhere in this file --
        # presumably intended for fork/stale-cache detection. TODO confirm.
        self.pid = os.getpid()
        # Cache of generated wrapper callables keyed by arg_key(). Mutated in
        # __call__, so annotate with a mutable Dict (typing.Mapping is a
        # read-only protocol and mis-documents this attribute).
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the cached wrapper for these ranks, generating it on
        first use.

        *args* is ``(input_tensor, tensor_indices, values, accumulate)``.
        """
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Generate source for this rank combination and persist it in the
            # shared code cache so other processes can reuse it.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                code,
            )
            file_name = f"index_put_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            # spec_from_file_location may return None (or a spec without a
            # loader); fail loudly instead of with an AttributeError.
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"cannot load generated index_put module from {file_path}"
                )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = m._index_put_wrapper
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from input rank, index count, and index rank."""
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

289 

290 

# Module-level singleton: caches one generated wrapper per rank signature.
_index_put_func = IndexPutFunction()

292 

293 

def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: scatter *values* into a clone of *inp* at
    *indices* and return the clone. With ``accumulate=True`` values are added
    instead of overwritten.
    """
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        # Single boolean mask: convert it to integer coordinate tensors.
        mask = indices[0]
        if mask.device != inp.device:
            mask = mask.to(inp.device)
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices):]

        if values.numel() == 1:
            # Scalar value: materialize it at the scatter shape.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors onto the input's device, keeping None markers as-is.
    relocated = []
    for index in indices:
        if index is not None and index.device != inp.device:
            relocated.append(index.to(inp.device))
        else:
            relocated.append(index)
    indices = relocated

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices):]

    # Filter out None values for the kernel call (only tensor indices).
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out

339 

340 

def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: scatter *values* into *inp* at *indices* and
    return *inp*. With ``accumulate=True`` values are added instead of
    overwritten.
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        # Single boolean mask: convert it to integer coordinate tensors.
        mask = indices[0]
        if mask.device != inp.device:
            mask = mask.to(inp.device)
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices):]

        if values.numel() == 1:
            # Scalar value: materialize it at the scatter shape.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors onto the input's device, keeping None markers as-is.
    relocated = []
    for index in indices:
        if index is not None and index.device != inp.device:
            relocated.append(index.to(inp.device))
        else:
            relocated.append(index)
    indices = relocated

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices):]

    # Filter out None values for the kernel call (only tensor indices).
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp