Coverage for src/flag_gems/runtime/backend/_metax/ops/index_put.py: 0%
233 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger("flag_gems." + __name__)
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned elementwise maximum of the index shapes.

    ``None`` entries (basic-indexing markers) are ignored.  Returns an
    empty list when no tensor index is present.
    """
    tensors = [entry for entry in indices if entry is not None]
    if not tensors:
        return []
    rank = max(t.ndim for t in tensors)
    result = []
    # Walk axes from the rightmost; a tensor contributes its dim size only
    # when it is deep enough to have that axis (right-aligned broadcasting).
    for offset in range(rank):
        dim = 0
        for t in tensors:
            pos = t.ndim - 1 - offset
            if pos >= 0:
                dim = max(dim, t.shape[pos])
        result.append(dim)
    result.reverse()
    return result
def broadcast_indices(indices, target_shape):
    """Broadcast every tensor index to ``target_shape`` in place.

    ``None`` entries and tensors whose shape already matches are left
    untouched; mismatching tensors are replaced with broadcast views.
    """
    wanted = tuple(target_shape)
    for pos in range(len(indices)):
        entry = indices[pos]
        if entry is None or tuple(entry.shape) == wanted:
            continue
        indices[pos] = torch.broadcast_to(entry, target_shape)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header shared by every generated index_put module."""
    triton_lines = (
        "import triton",
        "import triton.language as tl",
    )
    for line in triton_lines:
        code.writeline(line)
    code.newline()

    flag_gems_lines = (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in flag_gems_lines:
        code.writeline(line)
    code.newline()
    code.newline()
    return code
def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton kernel source for an index_put specialized to the
    given ranks.

    Args:
        inp_rank: rank of the destination tensor.
        indices_len: number of index tensors (they address the leading
            ``indices_len`` dims of the input).
        index_rank: common rank of the (already broadcast) index tensors.
        kernel_name: name given to the generated ``@triton.jit`` function.
        code: buffer the source is written into.

    Returns:
        The same ``code`` buffer, for chaining.

    The generated kernel runs on a 2D grid: axis 0 walks the flattened
    index space (M elements), axis 1 walks the flattened trailing,
    non-indexed input dims (N elements).
    """
    code.writeline("@libentry()")
    code.writeline(
        '@triton.autotune(configs=runtime.get_tuned_config("index_put"), key=["M", "N"], restore_value=["input_ptr"])'
    )
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Parameter list: pointers, then shapes and strides for input /
        # each index tensor / values, then launch sizes and constexprs.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there are no trailing dims to
        # tile, so axis 1 degenerates to a single lane.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose flat offset0 into per-axis coordinates of the index
        # tensors; indices0's shape is used since all index tensors have
        # been broadcast to the same shape by the caller.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose flat offset1 into coordinates of the trailing
        # (non-indexed) input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        # Load the actual index values, one per index tensor.
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked out of the store entirely.
        # NOTE(review): negative indices are dropped rather than wrapped
        # Python-style — confirm callers never pass negative indices.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: indexed dims use the loaded index values,
        # trailing dims use the decomposed offset1 coordinates.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Source offset into values, laid out as index dims followed by the
        # trailing input dims.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        # accumulate=True needs atomics because duplicate indices may
        # target the same destination element.
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that unpacks shapes/strides from the runtime
    tensors and launches the generated kernel.

    Args:
        inp_rank / indices_len / index_rank: same specialization parameters
            as ``generate_index_put_kernel``; the wrapper's argument list
            must mirror the kernel's parameter list exactly.
        wrapper_name: name of the generated Python function.
        kernel_name: name of the kernel the wrapper launches.
        code: buffer the source is written into.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # M = flattened index space, N = flattened trailing input dims;
        # these drive the 2D launch grid below.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
            code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order mirrors the kernel's parameter list emitted by
            # generate_index_put_kernel: pointers, shapes, strides, sizes.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Assemble a complete generated module into ``code``: the shared
    import header, the specialized Triton kernel, and its Python wrapper.

    ``inputs`` is the (input, indices, values) triple; only the ranks of
    the input and the index tensors are read here.
    """
    input_rank = inputs[0].ndim
    # None entries are basic-indexing placeholders; only real tensors
    # participate in the specialization.
    real_indices = [entry for entry in inputs[1] if entry is not None]
    num_indices = len(real_indices)
    if num_indices == 0:
        raise ValueError("At least one non-None index tensor is required")
    idx_rank = real_indices[0].ndim
    code = generate_imports(code)
    generate_index_put_kernel(input_rank, num_indices, idx_rank, kernel_name, code)
    generate_index_put_wrapper(
        input_rank, num_indices, idx_rank, wrapper_name, kernel_name, code
    )
    return code
class IndexPutFunction:
    """JIT dispatcher for index_put.

    Generates, caches (on disk and in memory), and invokes a Triton kernel
    specialized to the (input rank, number of indices, index rank)
    combination of each call.
    """

    def __init__(self):
        # Recorded at construction; identifies the owning process.
        self.pid = os.getpid()
        # In-memory cache of generated wrappers keyed by arg_key().
        # Annotated as Dict (not Mapping) because it is mutated below.
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the specialized wrapper, generating it on first use.

        ``args`` is (input, tensor_indices, values, accumulate); the
        generated wrapper mutates *input* in place and returns it.
        """
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Cache miss: generate the module source, persist it to the
            # code cache, import it, and memoize its wrapper entry point.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                code,
            )

            file_name = f"index_put_{key}.py"
            file_path = code_cache_dir() / file_name
            # write_atomic avoids torn files when multiple processes race.
            write_atomic(file_path, code.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = m._index_put_wrapper
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from the ranks that the generated code is
        specialized on: input rank, index count, and index tensor rank."""
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        # An empty index list must still produce a well-formed key.
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"
# Module-level dispatcher instance shared by index_put and index_put_.
_index_put_func = IndexPutFunction()
def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of ``inp`` with ``values``
    written (or atomically accumulated, when ``accumulate``) at the
    positions selected by ``indices``."""
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)

    # Lower a single boolean mask to integer coordinate tensors first.
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]
        if mask.device != inp.device:
            mask = mask.to(inp.device)
        indices = list(torch.where(mask))

        num_selected = indices[0].numel()
        target_shape = (num_selected,) + inp.shape[len(indices) :]
        # A scalar value is materialized across every selected position; a
        # flat value-per-hit tensor is expanded to the full target shape.
        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == num_selected:
            values = values.reshape((num_selected,)).expand(target_shape)

    # Move index tensors onto the input's device; None markers pass through.
    indices = [
        entry
        if entry is None or entry.device == inp.device
        else entry.to(inp.device)
        for entry in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Drop None markers only AFTER broadcasting, which may replace entries
    # with new broadcast views.
    tensor_indices = [entry for entry in indices if entry is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    # Work on a clone so the caller's tensor is left untouched.
    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out
def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: write (or atomically accumulate, when
    ``accumulate``) ``values`` into ``inp`` at the positions selected by
    ``indices`` and return ``inp``."""
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)

    # Lower a single boolean mask to integer coordinate tensors first.
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]
        if mask.device != inp.device:
            mask = mask.to(inp.device)
        indices = list(torch.where(mask))

        num_selected = indices[0].numel()
        target_shape = (num_selected,) + inp.shape[len(indices) :]
        # A scalar value is materialized across every selected position; a
        # flat value-per-hit tensor is expanded to the full target shape.
        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == num_selected:
            values = values.reshape((num_selected,)).expand(target_shape)

    # Move index tensors onto the input's device; None markers pass through.
    indices = [
        entry
        if entry is None or entry.device == inp.device
        else entry.to(inp.device)
        for entry in indices
    ]

    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Drop None markers only AFTER broadcasting, which may replace entries
    # with new broadcast views.
    tensor_indices = [entry for entry in indices if entry is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp