Coverage for src/flag_gems/runtime/backend/_metax/ops/index.py: 0%
281 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger("flag_gems." + __name__)
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Compute the broadcast target shape over all non-None index tensors.

    Axes are aligned from the trailing dimension; each output axis holds the
    maximum extent found among the index tensors at that (right-aligned)
    position. Returns [] when no tensor index is present.
    """
    real_indices = [idx for idx in indices if idx is not None]
    if not real_indices:
        return []
    max_rank = max(idx.ndim for idx in real_indices)
    shape = [0] * max_rank
    # Walk axes from the right so tensors of differing rank align correctly.
    for offset in range(max_rank):
        extent = 0
        for idx in real_indices:
            axis = idx.ndim - 1 - offset
            if axis >= 0:
                extent = max(extent, idx.shape[axis])
        shape[max_rank - 1 - offset] = extent
    return shape
def broadcast_indices(indices, target_shape):
    """Expand, in place, every non-None entry of ``indices`` to ``target_shape``.

    Entries that already have the target shape (or are None placeholders)
    are left untouched.
    """
    target = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if idx is None or tuple(idx.shape) == target:
            continue
        indices[pos] = torch.broadcast_to(idx, target)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import header shared by every generated index module."""
    for line in ("import triton", "import triton.language as tl"):
        code.writeline(line)
    code.newline()
    flag_gems_imports = (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in flag_gems_imports:
        code.writeline(line)

    code.newline()
    code.newline()
    return code
def generate_index_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton JIT kernel source for an advanced-index gather.

    The generated kernel maps a 2D grid over (M, N):
      * axis 0 (M) enumerates the flattened broadcast index space
        (``index_rank`` dimensions, shared by all index tensors);
      * axis 1 (N) enumerates the flattened trailing input dimensions not
        consumed by index tensors (``inp_rank - indices_len`` of them).

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of (non-None) index tensors.
        index_rank: common rank of the broadcast index tensors.
        kernel_name: name given to the generated ``@triton.jit`` function.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        code.writeline('configs=runtime.get_tuned_config("index"),')
        code.writeline('key=["M", "N"],')
        code.writeline('strategy=["align32", "align32"],')
        code.writeline("warmup=5,")
        code.writeline("rep=10,")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel parameter list: data pointers, then shapes, strides, sizes,
        # and the tunable block-size constexprs.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["out_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [f"out_stride{i}," for i in range(index_rank + inp_rank - indices_len)]
        args += [
            "M,",
            "N,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there is no trailing subspace, so
        # axis 1 degenerates to a single column.
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Decompose offset0 into multi-dim coordinates of the broadcast index
        # space; all index tensors share indices0's shape after broadcasting.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Decompose offset1 into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        # Load the index values themselves (masked on the M axis).
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Out-of-range index values are masked out of the load/store rather
        # than trapping. NOTE(review): negative indices are masked, not
        # wrapped — confirm callers normalize them if wraparound is expected.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Linear offset into the input: indexed dims use the loaded indices,
        # trailing dims use the decomposed offset1 coordinates.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Linear offset into the output: broadcast dims first, trailing after.
        comp = [f"indices_idx{i} * out_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * out_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"out_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(input_ptr + input_offset , mask = mask)")
        code.writeline("tl.store(out_ptr + out_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that launches the generated index kernel.

    The wrapper signature is ``(input, indices, out)``: it unpacks shapes and
    strides, computes the launch grid over (M, N), invokes the kernel, and
    returns the gathered output tensor.

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of (non-None) index tensors.
        index_rank: common rank of the broadcast index tensors.
        wrapper_name: name given to the generated wrapper function.
        kernel_name: name of the kernel the wrapper launches.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, out):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("out_shape = out.shape")
        code.writeline("out_stride = out.stride()")
        # M = flattened broadcast index space; N = trailing input subspace.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Argument order must match generate_index_kernel's parameters.
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["out,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"out_stride[{i}]," for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,"]
            code.writelines(args)
        code.writeline(")")
        # Fix: the wrapper's result is the gathered tensor, not the source.
        # (Previously emitted "return input", which is the in-place idiom and
        # wrong for an out-of-place gather; callers in this file discard the
        # value, so the change is backward-compatible.)
        code.writeline("return out")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Render the full source of a generated index module into ``code``.

    ``inputs`` is ``(input_tensor, indices)``; rank information extracted
    from it parameterizes the emitted kernel and wrapper.

    Raises:
        ValueError: if every entry of ``inputs[1]`` is None.
    """
    input_tensor, raw_indices = inputs[0], inputs[1]
    # None entries are basic-indexing markers; only tensors drive codegen.
    real_indices = [idx for idx in raw_indices if idx is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    inp_rank = input_tensor.ndim
    index_rank = real_indices[0].ndim
    code = generate_imports(code)
    generate_index_kernel(inp_rank, len(real_indices), index_rank, kernel_name, code)
    generate_index_wrapper(
        inp_rank, len(real_indices), index_rank, wrapper_name, kernel_name, code
    )
    return code
class IndexFunction:
    """Dispatcher that JIT-generates and caches one index overload per rank
    signature (input rank, number of indices, index rank).

    Generated modules are written into the code cache directory and imported
    once; later calls with the same signature reuse the cached wrapper.
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, out = args
        key = self.arg_key(inp, tensor_indices)
        try:
            overload = self.overloads[key]
        except KeyError:
            overload = self._generate_overload(key, (inp, tensor_indices))
            self.overloads[key] = overload
        return overload(*args)

    def _generate_overload(self, key, full_args):
        """Render, write to the code cache, and import a new overload module."""
        buffer = IndentedBuffer()
        buffer = generate_code(
            full_args,
            "_index_wrapper",
            "_index_jit_function",
            buffer,
        )

        file_name = f"index_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, buffer.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_wrapper")

    def arg_key(self, *args, **kwargs):
        """Build the cache key from the ranks of the input and its indices."""
        inp, tensor_indices = args[0], args[1]
        indices_len = len(tensor_indices)
        index_rank = tensor_indices[0].ndim if indices_len else 0
        return f"inp_rank_{inp.ndim}_indices_len_{indices_len}_index_rank_{index_rank}"
# Module-level singleton: caches generated index overloads across all calls.
_index_func = IndexFunction()
def index(inp, indices):
    """Advanced tensor indexing, ``inp[indices]``, following PyTorch
    advanced-indexing semantics.

    Args:
        inp: source tensor.
        indices: sequence of index tensors and/or ``None`` placeholders for
            untouched dimensions. Boolean / byte masks are expanded into long
            index tensors via ``nonzero``.

    Returns:
        A new tensor containing the gathered values.

    Raises:
        ValueError: if ``indices`` is empty.
        IndexError: if there are more indices than ``inp`` has dimensions, or
            a mask's shape does not match the input dims it spans.
        TypeError: if an index tensor has an unsupported dtype.
    """
    logger.debug("GEMS INDEX")
    original_indices = list(indices)  # kept to locate the first tensor dim later
    indices = list(indices)

    if not indices:
        raise ValueError("at least one index must be provided")

    # Move index tensors onto the input's device so the kernel can read them.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Step 1: Process indices (expand bool/byte masks, keep None markers).
    # Following PyTorch meta implementation.
    processed_indices = []
    for i, index in enumerate(indices):
        if index is not None:
            # Fix: accept torch.uint8 ("byte") masks alongside bool, matching
            # both the error message below and PyTorch's legacy byte-mask
            # semantics; torch.int8 kept for backward compatibility.
            if index.dtype in [torch.bool, torch.uint8, torch.int8]:
                # Convert the mask to long indices, one per mask dimension.
                nonzero = index.nonzero()
                k = len(processed_indices)
                if k + index.ndim > inp.ndim:
                    raise IndexError(
                        f"too many indices for tensor of dimension {inp.ndim}"
                    )
                # A mask must exactly match the input dims it spans.
                for j in range(index.ndim):
                    if index.shape[j] != inp.shape[k + j]:
                        raise IndexError(
                            f"The shape of the mask {index.shape} at index {i} "
                            f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                        )
                # Each mask dimension contributes one long index tensor.
                for j in range(index.ndim):
                    processed_indices.append(nonzero.select(1, j))
            elif index.dtype in [torch.long, torch.int, torch.int32, torch.int64]:
                processed_indices.append(index)
            else:
                raise TypeError(
                    "tensors used as indices must be long, int, byte or bool tensors"
                )
        else:
            processed_indices.append(None)
    indices = processed_indices

    # Check indices count after mask expansion.
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )

    # Saved for deciding whether a transpose / post-permute is needed.
    has_any_tensor = any(idx is not None for idx in indices)
    starts_with_none = indices[0] is None if indices else False

    # Step 2: Broadcast tensor indices to a common shape (None untouched).
    tensor_indices = [idx for idx in indices if idx is not None]
    if tensor_indices:
        if len(tensor_indices) > 1:
            tensor_indices = list(torch.broadcast_tensors(*tensor_indices))
        # Write broadcasted tensors back into their original slots.
        tensor_idx = 0
        for i in range(len(indices)):
            if indices[i] is not None:
                indices[i] = tensor_indices[tensor_idx]
                tensor_idx += 1

    # Step 3: Pad with None markers up to input.ndim.
    while len(indices) < inp.ndim:
        indices.append(None)

    # Step 4: Check whether all non-None tensors are adjacent
    # (a "contiguous subspace"). State machine: 0 = before tensors,
    # 1 = inside tensor run, 2 = after tensor run.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # Transpose if not contiguous OR starts with None (and has tensor indices):
    # move all indexed dims to the front so the kernel sees a leading subspace.
    need_post_process = False
    first_tensor_dim = None
    if not has_contiguous_subspace or (starts_with_none and has_any_tensor):
        dims = []
        transposed_indices = []
        # First all non-None index positions, then all None positions.
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        inp = inp.permute(dims)
        indices = transposed_indices

    # Post-processing is needed only when the tensor indices were originally
    # contiguous but did not start at dim 0 (leading None markers): the
    # broadcast dims must be permuted back to their original position.
    if starts_with_none and has_any_tensor and has_contiguous_subspace:
        need_post_process = True
        for i, idx in enumerate(original_indices):
            if idx is not None:
                first_tensor_dim = i
                break

    # Step 5: Output shape = before_shape + replacement_shape + after_shape.
    before_shape = []
    after_shape = []
    replacement_shape = []

    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                # None after tensor indices -> trailing subspace.
                after_shape.append(inp.shape[dim])
            else:
                # None before tensor indices -> leading subspace.
                before_shape.append(inp.shape[dim])
        else:
            # The first tensor index fixes the (broadcast) replacement shape.
            if not replacement_shape:
                replacement_shape = list(index.shape)

    # Step 6: Allocate the output.
    out_shape = before_shape + replacement_shape + after_shape
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Step 7: Nothing to gather from an empty tensor.
    if inp.numel() == 0:
        return out

    # Step 8: Extract only tensor indices for the kernel.
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        # All None: pure reshape.
        return inp.view(*out_shape)

    # Step 9: Launch the generated kernel (fills `out` in place).
    _index_func(inp, tensor_indices, out)

    # Step 10: Move the broadcast dims back behind the leading None dims.
    if need_post_process:
        index_rank = tensor_indices[0].ndim
        pre_dims = list(range(index_rank, index_rank + first_tensor_dim))
        broadcast_dims = list(range(index_rank))
        post_dims = list(range(index_rank + first_tensor_dim, out.ndim))
        new_order = pre_dims + broadcast_dims + post_dims
        out = out.permute(new_order)

    return out