Coverage for src/flag_gems/runtime/backend/_metax/ops/index_put.py: 0%
232 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger("flag_gems." + __name__)
def get_max_rank_shape(indices: "List[torch.Tensor]") -> List[int]:
    """Return the per-dimension maximum shape over all tensor indices.

    Shapes are aligned at the trailing dimension (broadcast-style): the
    result has the rank of the largest index tensor, and each entry is the
    maximum size any index tensor has at that (right-aligned) position.
    None entries (basic-indexing markers) are ignored; an all-None or
    empty list yields [].
    """
    tensors = [entry for entry in indices if entry is not None]
    if not tensors:
        return []

    rank = max(t.ndim for t in tensors)
    result: List[int] = []
    for dim in range(rank):
        # Distance of this dimension from the trailing end, plus one.
        offset = rank - dim
        sizes = [t.shape[t.ndim - offset] for t in tensors if t.ndim >= offset]
        result.append(max(sizes, default=0))
    return result
def broadcast_indices(indices, target_shape):
    """Broadcast every tensor entry of *indices* to *target_shape*, in place.

    Entries that are None, or that already have the target shape, are left
    untouched (the original tensor object is kept).
    """
    target = tuple(target_shape)
    for pos, entry in enumerate(indices):
        if entry is None:
            continue
        if tuple(entry.shape) == target:
            continue
        indices[pos] = torch.broadcast_to(entry, target)
def generate_imports(code: "IndentedBuffer") -> "IndentedBuffer":
    """Write the import preamble shared by every generated index_put module.

    Returns the same buffer so calls can be chained.
    """
    triton_imports = [
        "import triton",
        "import triton.language as tl",
    ]
    flag_gems_imports = [
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    ]
    for line in triton_imports:
        code.writeline(line)
    code.newline()
    for line in flag_gems_imports:
        code.writeline(line)

    code.newline()
    code.newline()
    return code
def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton JIT kernel source for index_put into *code*.

    The generated kernel runs on a 2-D launch grid: axis 0 walks the M
    flattened positions of the (already broadcast) index tensors, axis 1
    walks the N trailing input elements that are not covered by any index
    tensor.

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of index tensors (all broadcast to one shape).
        index_rank: rank of each broadcast index tensor.
        kernel_name: name of the generated @triton.jit function.
        code: buffer the generated source lines are written into.

    Returns:
        The same *code* buffer, for chaining.
    """
    code.writeline("@libentry()")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel parameter list: pointers, then shapes, then strides, then
        # the launch sizes and the accumulate flag / block-size constexprs.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        # values has index_rank leading dims plus one dim per un-indexed
        # trailing input dim.
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr = 2,",
            "BLOCK_SIZE1: tl.constexpr = 2048,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there is no trailing block: use a
        # degenerate size-1 range on axis 1 (N is 1 in that case).
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel offset0 into multi-dim coordinates of the index tensors.
        # All index tensors share one broadcast shape, so indices0's shape
        # serves for every one of them.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel offset1 into coordinates over the trailing (un-indexed)
        # input dimensions.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        # Gather the actual index values from each index tensor.
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked off rather than wrapped;
        # NOTE(review): negative (Python-style) indices are therefore
        # skipped, not wrapped around — confirm this is intended.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Linear offset into input: indexed dims come from the gathered
        # index values, trailing dims from offset1's coordinates.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Linear offset into values: index-tensor coordinates first, then
        # the trailing input coordinates shifted into values' dim space.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # accumulate=True sums colliding indices via an atomic add;
        # otherwise a plain store is used (last writer wins).
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python launch wrapper for the generated kernel into *code*.

    The wrapper collects shapes/strides from its tensor arguments, computes
    the launch sizes M (flattened index positions) and N (product of the
    trailing, un-indexed input dims) and launches *kernel_name* over a 2-D
    grid.  It writes into its `input` argument in place and returns it.

    Returns:
        The same *code* buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        # M: number of flattened index positions (all index tensors share
        # one broadcast shape); N: trailing input elements per position.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
            code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must mirror the kernel parameter list emitted
            # by generate_index_put_kernel.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate a complete module (imports + kernel + wrapper) for one
    index_put rank signature.

    *inputs* is (input_tensor, indices, values); only tensor ranks are
    consulted.  Raises ValueError when *indices* holds no tensor entry.
    Returns the filled *code* buffer.
    """
    inp_rank = inputs[0].ndim
    # None entries mark basic-indexing positions and never reach the kernel.
    real_indices = [entry for entry in inputs[1] if entry is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    indices_len = len(real_indices)
    index_rank = real_indices[0].ndim

    code = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, code)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, code
    )
    return code
class IndexPutFunction:
    """Dispatches index_put calls to JIT-generated Triton wrappers.

    One wrapper module is generated, written to the code cache and imported
    per (input rank, index count, index rank) signature; subsequent calls
    with the same signature reuse the cached callable.
    """

    def __init__(self):
        # Process id at construction time.
        self.pid = os.getpid()
        # Cache: rank-signature key -> generated wrapper callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, values, accumulate = args
        key = self.arg_key(inp, tensor_indices, values)
        overload = self.overloads.get(key)
        if overload is None:
            # Cache miss: generate the module source, persist it, import it
            # and memoize the wrapper it exports.
            buffer = IndentedBuffer()
            buffer = generate_code(
                (inp, tensor_indices, values),
                "_index_put_wrapper",
                "_index_put_jit_function",
                buffer,
            )
            file_path = code_cache_dir() / f"index_put_{key}.py"
            write_atomic(file_path, buffer.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            overload = module._index_put_wrapper
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the cache key from the ranks of the call's tensors."""
        inp, tensor_indices, _values = args[0], args[1], args[2]
        index_rank = tensor_indices[0].ndim if tensor_indices else 0
        return (
            f"inp_rank_{inp.ndim}"
            f"_indices_len_{len(tensor_indices)}"
            f"_index_rank_{index_rank}"
        )
263_index_put_func = IndexPutFunction()
def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of *inp* with *values* written
    at the positions selected by *indices*.

    Args:
        inp: input tensor (not modified; a clone is written to).
        indices: sequence of index tensors (None entries mark basic-indexing
            positions); a single boolean tensor is treated as a mask.
        values: tensor of values to write; broadcast to the target shape.
        accumulate: when True, values are added to existing elements instead
            of overwriting them.
    """
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)
    # A single boolean tensor is mask indexing: convert it to integer
    # coordinate tensors via torch.where.
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        # K = number of selected positions; the target covers K positions
        # plus any trailing input dims the mask did not span.
        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar value: materialize it across the whole target.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # One value per selected position.
            # NOTE(review): expand aligns trailing dims, so this only
            # succeeds when the mask spans all input dims (target is (K,))
            # or sizes happen to line up — confirm partial-mask behavior.
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors to the input's device, leaving None markers alone.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Broadcast all index tensors to a common shape, then append the
    # trailing (un-indexed) input dims to form the values target shape.
    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    # Clone so the input itself is untouched; the generated kernel writes
    # into its first argument in place.
    out = inp.clone()
    _index_put_func(out, tensor_indices, values, accumulate)
    return out
def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: write *values* into *inp* at the positions
    selected by *indices* and return *inp*.

    Args:
        inp: input tensor, modified in place.
        indices: sequence of index tensors (None entries mark basic-indexing
            positions); a single boolean tensor is treated as a mask.
        values: tensor of values to write; broadcast to the target shape.
        accumulate: when True, values are added to existing elements instead
            of overwriting them.
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)
    # A single boolean tensor is mask indexing: convert it to integer
    # coordinate tensors via torch.where.
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        # K = number of selected positions; the target covers K positions
        # plus any trailing input dims the mask did not span.
        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            # Scalar value: materialize it across the whole target.
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            # One value per selected position.
            # NOTE(review): expand aligns trailing dims, so this only
            # succeeds when the mask spans all input dims (target is (K,))
            # or sizes happen to line up — confirm partial-mask behavior.
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors to the input's device, leaving None markers alone.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Broadcast all index tensors to a common shape, then append the
    # trailing (un-indexed) input dims to form the values target shape.
    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]
    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    # The generated kernel writes into its first argument in place.
    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp