Coverage for src/flag_gems/runtime/backend/_metax/ops/index_select.py: 0%
72 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems.runtime as runtime
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as tle
# Module logger, namespaced under the package-wide "flag_gems" logger hierarchy.
logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Column-gather over a row-major (M, N) input into a row-major
    # (M, index_len) output: out[m, j] = inp[m, index[j]].
    # Launched on a 2D grid: axis 0 tiles the M rows, axis 1 tiles the
    # index_len output columns.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    # (BLOCK_M, 1) column vector of row ids handled by this program.
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M
    # (BLOCK_N,) vector of output-column ids handled by this program.
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    # Broadcasts to a (BLOCK_M, BLOCK_N) mask.  NOTE(review): this relies on
    # Triton compiling `and` between tensors as an elementwise logical-and;
    # `&` is the unambiguous spelling — confirm against the Triton in use.
    out_mask = rows_mask and (cols_offsets < index_len)

    # Out-of-range lanes read index 0 (other=0); their results are discarded
    # by out_mask at store time, so the placeholder value is never visible.
    indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0)
    inp_off = rows_offsets * N + indices[None, :]
    out_off = rows_offsets * index_len + cols_offsets[None, :]

    # Load is masked only on rows; columns are safe because `indices` was
    # clamped to 0 for out-of-range lanes above.
    selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0)
    tl.store(out + out_off, selected, mask=out_mask)
@libentry()
@triton.jit
def index_select_2d_opt_kernel(inp, out, M, N, index, BLOCK_SIZE: tl.constexpr):
    # Row-gather for a contiguous (N, M) input: each program copies one
    # selected row of length M, i.e. out[pid, :] = inp[index[pid], :].
    # Launched on a 1D grid of size index_len, so `pid` is the output row.
    pid = tle.program_id(axis=0)

    row_index = tl.load(index + pid)
    row_offset = row_index * M
    # Scalar guard: if the index is out of range the whole copy is masked off.
    rows_mask = row_index < N

    # Stream the row in BLOCK_SIZE chunks so arbitrarily wide rows fit.
    for m in range(0, tl.cdiv(M, BLOCK_SIZE)):
        cols_offsets = m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        cols_mask = cols_offsets < M
        # NOTE(review): `and` between a scalar and a tensor mask — relies on
        # Triton's logical-and lowering; `&` would be unambiguous.
        block_mask = rows_mask and cols_mask
        cur_inp = tl.load(inp + row_offset + cols_offsets, mask=block_mask, other=0.0)
        out_offset = pid * M + cols_offsets
        tl.store(out + out_offset, cur_inp, mask=block_mask)
# Swap the first two dimensions of a 2D tensor for a better memory access pattern.
def dim_transpose(inp):
    """Return a contiguous copy of ``inp`` with dims 0 and 1 swapped."""
    return inp.transpose(0, 1).contiguous()
def index_select(inp, dim, index):
    """Select slices of ``inp`` along ``dim`` at the positions in ``index``.

    Mirrors ``torch.index_select``: the result has the same shape as ``inp``
    except that its size along ``dim`` equals ``index.numel()``.

    Args:
        inp: input tensor.
        dim: dimension to gather along; negative values are accepted.
        index: 0-D or 1-D integer tensor of positions, each in
            ``[0, inp.size(dim))``.

    Returns:
        A new tensor with ``out.size(dim) == index.numel()``.
    """
    logger.debug("METAX GEMS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    # Fix: the original asserted a generator expression, which is always
    # truthy, so indices were never actually validated.  Check on-tensor.
    assert bool(((index >= 0) & (index < inp.size(dim))).all()), "Index out of range"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    dim = dim % inp.ndim
    index_len = index.numel()
    inp_shape = list(inp.shape)
    N = inp_shape[dim]  # size of the selected dimension
    M = inp.numel() // N  # product of all the other dimensions

    if inp.ndim == 2 and dim == 0:
        # Fast path: rows of a contiguous 2D tensor are contiguous, so one
        # program per selected row can stream the whole row.
        out = torch.empty(
            (index_len, inp_shape[1]), dtype=inp.dtype, device=inp.device
        )
        BLOCK_SIZE = min(triton.next_power_of_2(M), 4096)
        index_select_2d_opt_kernel[(index_len,)](
            inp, out, M, N, index, BLOCK_SIZE=BLOCK_SIZE
        )
        return out

    # General path: move `dim` to the innermost position so the kernel can
    # treat the input as a contiguous (M, N) matrix.
    inp = dim_compress(inp, dim)
    out_shape = list(inp.shape)
    out_shape[inp.ndim - 1] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(index_len, meta["BLOCK_N"]),
    )
    index_select_kernel[grid](inp, out, M, N, index, index_len)
    if dim != out.ndim - 1:
        # Undo dim_compress: move the last axis back to position `dim`.
        order = list(range(out.ndim - 1))
        order.insert(dim, out.ndim - 1)
        return out.permute(order).contiguous()
    return out