Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index.py: 0%
293 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger(__name__)
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the per-dimension maximum of the right-aligned index shapes.

    ``None`` entries (basic-indexing markers) are skipped.  The shapes of
    the remaining tensors are aligned at their last dimension, and each
    output position holds the largest extent seen there.  Returns ``[]``
    when no tensor index is present.
    """
    real = [t for t in indices if t is not None]
    if not real:
        return []
    rank = max(t.ndim for t in real)
    result = [0] * rank
    for t in real:
        # Right-align this tensor's shape inside the `rank`-long result.
        offset = rank - t.ndim
        for axis, size in enumerate(t.shape):
            pos = offset + axis
            if size > result[pos]:
                result[pos] = size
    return result
def broadcast_indices(indices, target_shape):
    """Broadcast every tensor entry of ``indices`` to ``target_shape`` in place.

    ``None`` entries are left untouched, and tensors that already have the
    target shape are kept as the same object (no copy, no new view).
    """
    want = tuple(target_shape)
    for pos in range(len(indices)):
        candidate = indices[pos]
        if candidate is None:
            continue
        if tuple(candidate.shape) != want:
            indices[pos] = torch.broadcast_to(candidate, target_shape)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import preamble used by every generated index module.

    Returns the same buffer for chaining.
    """
    third_party = (
        "import triton",
        "import triton.language as tl",
        "import builtins",
    )
    for line in third_party:
        code.writeline(line)
    code.newline()
    project = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in project:
        code.writeline(line)
    code.newline()
    code.newline()
    return code
def generate_index_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Emit the Triton kernel source for advanced indexing into ``code``.

    The generated kernel is launched on a 2D grid:
      * axis 0 walks the flattened broadcast index space (``M`` elements
        spread over ``index_rank`` dimensions shared by all index tensors);
      * axis 1 walks the flattened trailing input dimensions that are not
        consumed by index tensors (``N`` elements).

    Args:
        inp_rank: rank of the input tensor being indexed.
        indices_len: number of non-None index tensors; they consume the
            leading ``indices_len`` input dimensions.
        index_rank: common rank of the (already broadcast) index tensors.
        kernel_name: name for the generated ``@triton.jit`` function.
        code: buffer the source lines are appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.newline()
    code.newline()
    # Tile-size heuristics evaluated at launch from the runtime M/N.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline(
            'return builtins.max(1, triton.next_power_of_2(triton.cdiv(args["M"], 12)))'
        )
    code.newline()
    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline(
            'return builtins.max(1, builtins.min(triton.next_power_of_2(args["N"]), 4096))'
        )
    code.newline()
    code.newline()
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_SIZE0": heur_block_m,')
            code.writeline('"BLOCK_SIZE1": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    # Alternative autotuning decorator, deliberately kept for reference.
    # code.writeline("@libtuner(")
    # with code.indent():
    # code.writeline('configs=runtime.get_tuned_config("index"),')
    # code.writeline('key=["M", "N"],')
    # code.writeline('strategy=["align32", "align32"],')
    # code.writeline("warmup=5,")
    # code.writeline("rep=10,")
    # code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # Kernel parameter list: pointers, shapes, strides, sizes, tiles.
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["out_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [f"out_stride{i}," for i in range(index_rank + inp_rank - indices_len)]
        args += [
            "M,",
            "N,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")
    with code.indent():
        code.writeline("pid0 = tle.program_id(axis=0)")
        code.writeline("pid1 = tle.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dim is indexed there is no trailing subspace,
        # so axis 1 degenerates to a single lane (offset1 is always 0).
        if inp_rank == indices_len:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel offset0 into per-dimension coordinates of the index space.
        # NOTE(review): uses indices0's shape for all index tensors — valid
        # only because callers pass already-broadcast (equal-shape) indices.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel offset1 into coordinates of the non-indexed trailing dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        # Load the index values themselves, masked against M.
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(comp)}, mask=mask0, other=0)"
            )
        code.newline()
        # Range-check loaded indices on top of the tile bounds.
        # NOTE(review): negative index values are masked out, not wrapped,
        # so Python-style negative indexing must be normalized by callers.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        # Linear input offset: indexed dims use the gathered index values,
        # trailing dims use the unraveled offset1 coordinates.
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp)}")
        # Linear output offset: index-space coords first, trailing dims after.
        comp = [f"indices_idx{i} * out_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * out_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"out_offset = {' + '.join(comp)}")
        code.newline()
        code.writeline("cur_value = tl.load(input_ptr + input_offset , mask = mask)")
        code.writeline("tl.store(out_ptr + out_offset, cur_value, mask=mask)")
    code.newline()
    code.newline()
    return code
def generate_index_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Emit the Python wrapper that launches the generated index kernel.

    The wrapper signature is ``wrapper_name(input, indices, out)``; it
    unpacks shapes/strides, computes the 2D grid (M = flattened index
    space, N = volume of the trailing non-indexed input dims) and launches
    ``kernel_name``.

    Args:
        inp_rank: rank of the input tensor.
        indices_len: number of index tensors passed to the wrapper.
        index_rank: common rank of the (broadcast) index tensors.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the kernel the wrapper launches.
        code: buffer the source lines are appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(input, indices, out):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for i in range(indices_len):
            code.writeline(f"indices{i}_shape = indices[{i}].shape")
            code.writeline(f"indices{i}_stride = indices[{i}].stride()")
        code.writeline("out_shape = out.shape")
        code.writeline("out_stride = out.stride()")
        # M: flattened broadcast index space (all index tensors share it);
        # N: product of the input dims not consumed by index tensors.
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            args = ["input,"]
            args += [f"indices[{i}]," for i in range(indices_len)]
            args += ["out,"]
            args += [f"input_shape[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_shape[{j}]," for j in range(index_rank)]
            args += [f"input_stride[{i}]," for i in range(inp_rank)]
            for i in range(indices_len):
                args += [f"indices{i}_stride[{j}]," for j in range(index_rank)]
            args += [
                f"out_stride[{i}]," for i in range(index_rank + inp_rank - indices_len)
            ]
            args += ["M,", "N,"]
            code.writelines(args)
        code.writeline(")")
        # FIX: the wrapper previously ended with "return input", handing
        # back the untouched input tensor. The gather result lives in
        # `out`, so that is what the wrapper must return.
        code.writeline("return out")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Assemble a full generated module: imports, kernel, and wrapper.

    ``inputs`` is ``(input_tensor, indices)``; ``None`` entries in
    ``indices`` are ignored when deriving the rank signature.

    Raises:
        ValueError: if every entry of ``indices`` is ``None``.
    """
    input_tensor, raw_indices = inputs[0], inputs[1]
    real_indices = [entry for entry in raw_indices if entry is not None]
    if not real_indices:
        raise ValueError("At least one non-None index tensor is required")
    inp_rank = input_tensor.ndim
    index_rank = real_indices[0].ndim
    num_indices = len(real_indices)
    code = generate_imports(code)
    generate_index_kernel(inp_rank, num_indices, index_rank, kernel_name, code)
    generate_index_wrapper(
        inp_rank, num_indices, index_rank, wrapper_name, kernel_name, code
    )
    return code
class IndexFunction:
    """Callable that JIT-generates and caches index kernels per rank signature.

    Each distinct (input rank, index count, index rank) triple gets its own
    generated module, written to the code cache directory and imported once;
    subsequent calls reuse the cached wrapper.
    """

    def __init__(self):
        # Owning process id, recorded at construction time.
        self.pid = os.getpid()
        # Cache: rank-signature key -> compiled wrapper callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, out = args
        full_args = (inp, tensor_indices)
        key = self.arg_key(*full_args)
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, full_args)
            self.overloads[key] = overload
        return overload(*args)

    def _build_overload(self, key, full_args):
        # Generate source, persist it to the code cache, import, and return
        # the wrapper entry point.
        buffer = IndentedBuffer()
        buffer = generate_code(
            full_args,
            "_index_wrapper",
            "_index_jit_function",
            buffer,
        )
        file_path = code_cache_dir() / f"index_{key}.py"
        write_atomic(file_path, buffer.getvalue())
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_wrapper")

    def arg_key(self, *args, **kwargs):
        # Key depends only on ranks and index count, not on shapes/dtypes.
        inp, tensor_indices = args[0], args[1]
        index_rank = tensor_indices[0].ndim if tensor_indices else 0
        return (
            f"inp_rank_{inp.ndim}_indices_len_{len(tensor_indices)}"
            f"_index_rank_{index_rank}"
        )
291_index_func = IndexFunction()
def index(inp, indices):
    """Advanced indexing: return ``inp[indices]`` via a generated kernel.

    Mirrors PyTorch's meta-implementation pipeline: boolean masks are
    expanded via ``nonzero``, index tensors are broadcast together, the
    input is permuted so tensor indices form a contiguous leading
    subspace, the kernel gathers into a freshly allocated output, and a
    final permute restores the dimension order PyTorch semantics require
    when the index run started after leading ``None`` dims.
    """
    logger.debug("GEMS INDEX")
    original_indices = list(indices)  # Save original indices for later checks
    indices = list(indices)
    if not indices:
        raise ValueError("at least one index must be provided")
    # Move index tensors onto the input's device so the kernel can read them.
    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]
    # Step 1: Process indices (convert bool/int8 to long, handle None)
    # Following PyTorch meta implementation
    processed_indices = []
    for i, index in enumerate(indices):
        if index is not None:
            # Check dtype
            # NOTE(review): the accepted mask dtypes here are int8 and bool,
            # while the error message below says "byte" (uint8) — uint8
            # masks would currently raise TypeError. Confirm this matches
            # the backend's intent.
            if index.dtype in [torch.int8, torch.bool]:
                # Convert boolean/int8 mask to long indices
                nonzero = index.nonzero()
                # k = how many input dims are already consumed so far; a
                # mask consumes index.ndim dims at once.
                k = len(processed_indices)
                if k + index.ndim > inp.ndim:
                    raise IndexError(
                        f"too many indices for tensor of dimension {inp.ndim}"
                    )
                # Check shape matches the input dims the mask covers.
                for j in range(index.ndim):
                    if index.shape[j] != inp.shape[k + j]:
                        raise IndexError(
                            f"The shape of the mask {index.shape} at index {i} "
                            f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                        )
                # Extract one long-index column per mask dimension.
                for j in range(index.ndim):
                    processed_indices.append(nonzero.select(1, j))
            elif index.dtype in [torch.long, torch.int, torch.int32, torch.int64]:
                processed_indices.append(index)
            else:
                raise TypeError(
                    "tensors used as indices must be long, int, byte or bool tensors"
                )
        else:
            processed_indices.append(None)
    indices = processed_indices
    # Check indices count (after mask expansion).
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )
    # Save for later use
    has_any_tensor = any(idx is not None for idx in indices)
    starts_with_none = indices[0] is None if indices else False
    # Step 2: Broadcast indices (only tensor indices, not None)
    tensor_indices = [idx for idx in indices if idx is not None]
    if tensor_indices:
        # Broadcast all tensor indices together to a common shape.
        if len(tensor_indices) > 1:
            tensor_indices = list(torch.broadcast_tensors(*tensor_indices))
        # Write the broadcast tensors back into their original positions.
        tensor_idx = 0
        for i in range(len(indices)):
            if indices[i] is not None:
                indices[i] = tensor_indices[tensor_idx]
                tensor_idx += 1
    # Step 3: Add missing None indices (pad to input.ndim)
    while len(indices) < inp.ndim:
        indices.append(None)
    # Step 4: Check if has contiguous subspace
    # (all non-None tensors are adjacent). States: 0 = before the tensor
    # run, 1 = inside it, 2 = after it; a tensor in state 2 breaks.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        # for-else: no second tensor run was found.
        has_contiguous_subspace = True
    # Transpose if not contiguous OR starts with None (and has tensor indices):
    # move all tensor-indexed dims to the front so the kernel can treat
    # them as a leading contiguous block.
    need_post_process = False
    first_tensor_dim = None
    if not has_contiguous_subspace or (starts_with_none and has_any_tensor):
        dims = []
        transposed_indices = []
        # First add all non-None index positions
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        # Then add all None positions
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        # Permute input
        inp = inp.permute(dims)
        indices = transposed_indices
    # Check if we need post-processing
    # (only when originally started with None and was contiguous: PyTorch
    # keeps leading basic dims in front of the index-result dims then).
    if starts_with_none and has_any_tensor and has_contiguous_subspace:
        need_post_process = True
        # Find first tensor dimension in original indices
        for i, idx in enumerate(original_indices):
            if idx is not None:
                first_tensor_dim = i
                break
    # Step 5: Now indices have contiguous subspace (after potential transpose)
    # Calculate output shape: before_shape + replacement_shape + after_shape
    before_shape = []
    after_shape = []
    replacement_shape = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                # None after tensor indices -> goes to after_shape
                after_shape.append(inp.shape[dim])
            else:
                # None before tensor indices -> goes to before_shape
                before_shape.append(inp.shape[dim])
        else:
            # First tensor index determines replacement_shape
            if not replacement_shape:
                replacement_shape = list(index.shape)
    # Step 6: Build output shape and create output tensor
    out_shape = before_shape + replacement_shape + after_shape
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    # Step 7: Handle empty tensor case
    if inp.numel() == 0:
        return out
    # Step 8: Extract only tensor indices for kernel
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        # All None, just reshape
        return inp.view(*out_shape)
    # Step 9: Call kernel with tensor indices (writes into `out` in place)
    _index_func(inp, tensor_indices, out)
    # Step 10: Post-process if needed (for originally contiguous tensor
    # indices starting with None): move the index-result dims back behind
    # the leading basic dims to match PyTorch's output layout.
    if need_post_process:
        # Calculate index_rank from the first tensor index
        index_rank = tensor_indices[0].ndim
        # Create permutation order to move broadcast dimensions to correct position
        pre_dims = list(range(index_rank, index_rank + first_tensor_dim))
        broadcast_dims = list(range(index_rank))
        post_dims = list(range(index_rank + first_tensor_dim, out.ndim))
        new_order = pre_dims + broadcast_dims + post_dims
        out = out.permute(new_order)
    return out