Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index_put.py: 0%
252 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
# Module logger under the "flag_gems" hierarchy; lstrip(".") removes any
# leading dots from __name__ (normally a no-op for absolute module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned, per-dimension maximum shape over index tensors.

    None entries (basic-indexing markers) are ignored.  Dimensions of the
    result default to 0 when no index tensor covers them.  Returns [] when
    there is no non-None index.
    """
    real = [t for t in indices if t is not None]
    if not real:
        return []
    rank = max(t.ndim for t in real)
    out = []
    # Walk dimensions from the last one, aligning all tensors at the right.
    for pos in range(rank):
        dim = 0
        for t in real:
            ax = t.ndim - 1 - pos
            if ax >= 0:
                dim = max(dim, t.shape[ax])
        out.append(dim)
    out.reverse()
    return out
def broadcast_indices(indices, target_shape):
    """Broadcast every non-None entry of *indices* to *target_shape*, in place.

    Entries already of the target shape are left untouched (same object).
    """
    want = tuple(target_shape)
    for pos in range(len(indices)):
        idx = indices[pos]
        if idx is None:
            continue
        if tuple(idx.shape) != want:
            indices[pos] = torch.broadcast_to(idx, target_shape)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import header and both block-size heuristics into *code*.

    Returns the same buffer for chaining.
    """
    for line in (
        "import triton",
        "import triton.language as tl",
        "import builtins",
    ):
        code.writeline(line)
    code.newline()
    for line in (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
    ):
        code.writeline(line)

    code.newline()
    code.newline()

    # BLOCK_SIZE0: spread M over ~12 programs; fall back to 2 when M == 0.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('if args["M"] == 0:')
        with code.indent():
            code.writeline("return 2")
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    # BLOCK_SIZE1: next power of two of N, capped at 8192.
    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 8192)')

    code.newline()
    code.newline()
    return code
def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit a specialized Triton kernel named *kernel_name* into *code*.

    The generated kernel scatters `values` into `input` using `indices_len`
    integer index tensors (advanced indexing on the leading `indices_len`
    dims of an `inp_rank`-D input; each index tensor has `index_rank` dims
    after broadcasting).  Grid axis 0 walks the flattened index space (M);
    grid axis 1 walks the remaining (non-indexed) input dims (N).
    """
    code.writeline("@libentry()")
    # code.writeline(
    #     '@triton.autotune(configs=runtime.get_tuned_config("index_put"), key=["M", "N"], restore_value=["input_ptr"])'
    # )
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_SIZE0": heur_block_m,')
            code.writeline('"BLOCK_SIZE1": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Pointer args first, then shapes/strides as tl.constexpr so Triton
        # specializes the kernel per shape/stride combination.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}: tl.constexpr," for j in range(index_rank)]
        args += [f"input_stride{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}: tl.constexpr," for j in range(index_rank)]
        # values has index_rank leading dims + the trailing non-indexed dims.
        args += [
            f"values_stride{i}: tl.constexpr,"
            for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M: tl.constexpr,",
            "N: tl.constexpr,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tl.program_id(axis=0)")
        code.writeline("pid1 = tl.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # No trailing dims: the second grid axis degenerates to one column.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose offset0 into per-dim coordinates of the broadcast index
        # shape; all index tensors share indices0's shape after broadcasting.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose offset1 into coordinates of the non-indexed trailing dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # NOTE(review): out-of-range (including negative) indices are masked
        # out, not wrapped — presumably the caller normalizes PyTorch-style
        # negative indices before launch; confirm.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: gathered coordinates on the indexed dims plus
        # the decomposed coordinates on the trailing dims.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # accumulate=True uses atomics: index tensors may repeat a destination.
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit a Python wrapper *wrapper_name* that launches *kernel_name*.

    The wrapper extracts shapes/strides from the runtime tensors, computes the
    launch grid, and forwards everything to the generated kernel in the same
    argument order that generate_index_put_kernel declares.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # M = number of gathered elements (index tensors share one shape);
        # N = product of the trailing, non-indexed input dims.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Must mirror the kernel's parameter order exactly.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate the complete module source (imports + kernel + wrapper).

    *inputs* is (input_tensor, indices_list, values); only ranks and the
    number of non-None index tensors are used for specialization.
    Raises ValueError when no non-None index tensor is present.
    """
    inp_rank = inputs[0].ndim
    # None entries are basic-indexing markers; only tensors reach the kernel.
    tensor_indices = [t for t in inputs[1] if t is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")
    indices_len = len(tensor_indices)
    index_rank = tensor_indices[0].ndim
    code = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code
class IndexPutFunction:
    """Dispatcher that generates, caches, and calls specialized kernels.

    One compiled wrapper exists per (input rank, #index tensors, index rank)
    combination; generated source is written to the on-disk code cache and
    imported as a fresh module.
    """

    def __init__(self):
        # Recorded at construction; presumably for fork/spawn detection — TODO confirm.
        self.pid = os.getpid()
        # In-memory cache of compiled wrappers, keyed by arg_key().
        # NOTE(review): annotated Mapping but mutated in __call__; Dict/dict
        # would describe the contract more accurately.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # args = (inp, tensor_indices, values, accumulate); writes into inp
        # via the generated wrapper and returns its result.
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Generate the module source and persist it in the code cache.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                code,
            )

            file_name = f"index_put_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # Import the freshly written file as its own module and cache
            # the wrapper it defines.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_put_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        # Specialization key: input rank, number of index tensors, index rank.
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"
# Process-wide dispatcher instance shared by index_put and index_put_.
_index_put_func = IndexPutFunction()
def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of *inp* with *values* scattered
    at *indices*.

    A single bool index tensor is treated as a mask and converted to integer
    coordinates via torch.where.  None entries in *indices* are filtered out
    before the kernel call.  accumulate=True adds instead of overwriting.
    """
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        # One coordinate tensor per mask dim, each of length K (= #True).
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar value: materialize it across all selected positions.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # NOTE(review): expand from (K,) only adds *leading* dims, so this
            # presumably assumes the mask covers all input dims — confirm.
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors to the input's device; keep None markers as-is.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    # Clone so the caller's tensor is untouched; the kernel writes in place.
    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out
def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: scatter *values* into *inp* at *indices* and return
    *inp*.

    Same preprocessing as index_put (bool-mask conversion, device moves,
    broadcasting), but the generated kernel writes directly into *inp*.
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        # One coordinate tensor per mask dim, each of length K (= #True).
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar value: materialize it across all selected positions.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # NOTE(review): expand from (K,) only adds *leading* dims, so this
            # presumably assumes the mask covers all input dims — confirm.
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors to the input's device; keep None markers as-is.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp