Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_put.py: 0%
244 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger(
12 f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
13)
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the elementwise-max shape over the tensor entries of *indices*.

    ``None`` entries (basic-indexing markers) are ignored.  Shapes are aligned
    at their trailing axes and each output dimension is the largest size any
    index tensor has on that aligned axis.  Returns ``[]`` when no tensor is
    present.
    """
    tensors = [t for t in indices if t is not None]
    if not tensors:
        return []
    rank = max(len(t.shape) for t in tensors)
    shape = [0] * rank
    # ``pos`` counts axes from the right (trailing axis first) so that
    # tensors of different rank are aligned at their last dimension.
    for pos in range(rank):
        largest = 0
        for t in tensors:
            axis = len(t.shape) - 1 - pos
            if axis >= 0:
                largest = max(largest, t.shape[axis])
        shape[rank - 1 - pos] = largest
    return shape
def broadcast_indices(indices, target_shape):
    """Broadcast each tensor entry of *indices* to *target_shape*, in place.

    Entries that are ``None`` or already have the target shape are left
    untouched.  The list is mutated; nothing is returned.
    """
    want = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if idx is None or tuple(idx.shape) == want:
            continue
        indices[pos] = torch.broadcast_to(idx, want)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the fixed import header shared by every generated module.

    Emits the triton imports, a blank line, the flag_gems helper imports,
    then two trailing blank lines.  Returns the same buffer for chaining.
    """
    for line in (
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(line)
    code.newline()
    for line in (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(line)

    code.newline()
    code.newline()
    return code
def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton kernel source for one index_put specialization.

    The kernel is specialized on:
      * inp_rank    -- rank of the destination (input) tensor
      * indices_len -- number of advanced (tensor) indices; they address the
                       leading ``indices_len`` dims of the input
      * index_rank  -- common rank of the (already broadcast) index tensors

    Grid layout: axis 0 walks the flattened index space (M elements); axis 1
    walks the trailing input dims not covered by indices (N elements).
    Returns the same buffer for chaining.
    """
    # Decorators: kernel registry/cache plus autotuning over the block sizes.
    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        code.writeline('configs=runtime.get_tuned_config("index_put"),')
        code.writeline('key=["M", "N"],')
        code.writeline('restore_value=["input_ptr"],')
        code.writeline('strategy=["align32", "align32"],')
        code.writeline("warmup=5,")
        code.writeline("rep=10,")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Parameter order here must match the call emitted by
        # generate_index_put_wrapper: pointers, shapes, strides, sizes,
        # then the constexpr flags/block sizes.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [
            f"values_stride{i}," for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When the indices cover every input dim there are no trailing
        # elements, so axis 1 degenerates to a single lane.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose offset0 into per-axis coordinates of the index tensors.
        # All index tensors share indices0's shape (broadcast beforehand).
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose offset1 into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked off rather than wrapped:
        # negative indices are dropped here, not interpreted Python-style.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Destination offset: loaded index values address the leading dims,
        # the decomposed trailing coordinates address the rest.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Source offset into the (broadcast) values tensor.
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        # accumulate=True uses atomic adds so duplicate indices sum; plain
        # stores are used otherwise.
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that launches the generated Triton kernel.

    The wrapper unpacks shapes and strides from the runtime tensors, computes
    M (flattened count of index positions) and N (product of the trailing
    input dims not addressed by indices), and launches ``kernel_name`` on a
    2-D grid.  Returns the same buffer for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must mirror the parameter list emitted by
            # generate_index_put_kernel exactly.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["values,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"values_stride[{i}],"
                for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,", "accumulate==True,"]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Fill *code* with kernel + wrapper source for this call signature.

    *inputs* is ``(input_tensor, indices, values)``; ``None`` index entries
    are ignored when deriving the specialization parameters.

    Raises:
        ValueError: when every entry of ``inputs[1]`` is ``None``.
    """
    inp_rank = inputs[0].ndim
    # Only real tensor indices participate in the specialization.
    real_indices = [t for t in inputs[1] if t is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    n_indices = len(real_indices)
    idx_rank = real_indices[0].ndim
    buf = generate_imports(code)
    generate_index_put_kernel(inp_rank, n_indices, idx_rank, kernel_name, buf)
    generate_index_put_wrapper(
        inp_rank, n_indices, idx_rank, wrapper_name, kernel_name, buf
    )
    return buf
class IndexPutFunction:
    """Cache of JIT-generated index_put overloads keyed by rank signature.

    Each distinct (input rank, index count, index rank) combination gets its
    own generated module: source is produced by generate_code, written to the
    code cache directory, imported once, and the resulting wrapper callable is
    memoized in ``self.overloads``.
    """

    def __init__(self):
        # Process id captured at construction time (not read in this file;
        # presumably used to namespace per-process state — TODO confirm).
        self.pid = os.getpid()
        # Maps arg_key(...) strings to imported wrapper callables.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # args = (inp, tensor_indices, values, accumulate); the indices are
        # expected to be non-None tensors already broadcast to a common shape.
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Cache miss: generate the specialized source, persist it in the
            # code cache, and import it as a fresh module.
            code = IndentedBuffer()
            code = generate_code(
                full_args,
                "_index_put_wrapper",
                "_index_put_jit_function",
                code,
            )
            file_name = f"index_put_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_put_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args, **kwargs):
        """Build the overload-cache key from (inp, tensor_indices, values)."""
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"
# Module-level singleton: its overload cache lives for the interpreter's
# lifetime, so each rank specialization is generated and imported only once.
_index_put_func = IndexPutFunction()
def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put.

    Returns a copy of *inp* with *values* written (or accumulated, when
    ``accumulate=True``) at the positions selected by *indices*.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT")

    index_list = list(indices)

    # A lone boolean mask is lowered to integer coordinate tensors first.
    if len(index_list) == 1 and index_list[0].dtype == torch.bool:
        bool_mask = index_list[0]

        if bool_mask.device != inp.device:
            bool_mask = bool_mask.to(inp.device)

        index_list = list(torch.where(bool_mask))

        # n_sel selected positions; pre-expand scalar or n_sel-sized values
        # so they line up with (n_sel,) + trailing input dims.
        n_sel = index_list[0].numel()
        sel_shape = (n_sel,) + inp.shape[len(index_list) :]

        if values.numel() == 1:
            values = torch.full(
                sel_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == n_sel:
            values = values.reshape((n_sel,)).expand(sel_shape)

    # Pull every index tensor onto the input's device; None markers pass through.
    index_list = [
        idx if idx is None or idx.device == inp.device else idx.to(inp.device)
        for idx in index_list
    ]

    # Broadcast all index tensors to a common shape, then append the trailing
    # input dims to obtain the shape values must broadcast to.
    common_shape = get_max_rank_shape(index_list)
    broadcast_indices(index_list, common_shape)
    common_shape += inp.shape[len(index_list) :]

    # Drop None markers only after broadcasting, which may replace entries.
    tensor_indices = [idx for idx in index_list if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, common_shape)

    result = inp.clone()
    _index_put_func(result, tensor_indices, values, accumulate)
    return result
def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: writes (or accumulates) *values* into *inp* at
    *indices* and returns *inp*.

    Falls back to the eager ATen implementation when *values* cannot be
    broadcast to the shape this generator expects.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT_")

    indices = list(indices)
    # A lone boolean mask is lowered to integer coordinates via where().
    if len(indices) == 1 and indices[0].dtype == torch.bool:
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        indices = list(torch.where(mask))

        # K selected positions; pre-expand scalar or K-sized values so they
        # line up with (K,) + trailing input dims.
        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    # Move index tensors onto the input's device; None markers pass through.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Broadcast all index tensors to a common shape, then append the trailing
    # input dims to obtain the shape values must broadcast to.
    target_shape = get_max_rank_shape(indices)
    broadcast_indices(indices, target_shape)
    target_shape += inp.shape[len(indices) :]

    # Filter out None values for kernel call (only tensor indices)
    # Must be done AFTER broadcast_indices, as broadcast may create new tensors
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        raise ValueError("At least one non-None index tensor is required")

    if values.device != inp.device:
        values = values.to(inp.device)
    try:
        values = torch.broadcast_to(values, target_shape)
    except Exception:
        # Shape combination this generator cannot handle: redispatch to the
        # CompositeExplicitAutograd (eager ATen) implementation instead.
        return torch.ops.aten.index_put_.default.redispatch(
            torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd),
            inp,
            tensor_indices,
            values,
            accumulate,
        )

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp