# Coverage report artifact for src/flag_gems/runtime/backend/_mthreads/ops/index_select.py:
# 0% of 101 statements covered (coverage.py v7.6.9, generated 2026-03-26 15:32 +0800).
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.index_select import index_select as default_index_select
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; the f-string resolves to the dotted path of this
# backend op module (…backend._mthreads.ops.index_select) so log records
# can be filtered per-op.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
@libentry()
@triton.jit
def index_select_dim0_1d_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    inp_row_stride,
    out_row_stride,
    row_size,
    num_indices,
    BLOCK_SIZE: tl.constexpr,
):
    """dim=0 index_select: one program copies one selected row, chunk by chunk."""
    row_id = tle.program_id(axis=0)

    # Which source row this program is responsible for copying.
    src_row = tl.load(index_ptr + row_id)

    # Base pointers of the source and destination rows.
    src_base = inp_ptr + src_row * inp_row_stride
    dst_base = out_ptr + row_id * out_row_stride

    # Sweep the row in BLOCK_SIZE-wide chunks, masking the ragged tail.
    for start in range(0, row_size, BLOCK_SIZE):
        lanes = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = lanes < row_size
        vals = tl.load(src_base + lanes, mask=in_bounds, other=0.0)
        tl.store(dst_base + lanes, vals, mask=in_bounds)
@libentry()
@triton.jit
def index_select_dim0_split_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    inp_row_stride,
    out_row_stride,
    row_size,
    num_indices,
    BLOCK_SIZE: tl.constexpr,
):
    """dim=0 index_select over a 2D launch grid, for very wide rows.

    Grid axis 0 walks the indices (one selected row each); grid axis 1 walks
    BLOCK_SIZE-wide column chunks of that row.
    """
    idx_id = tle.program_id(axis=0)
    chunk_id = tle.program_id(axis=1)

    # Source row picked by this program's index slot.
    src_row = tl.load(index_ptr + idx_id)

    # Base pointers of the source and destination rows.
    src_base = inp_ptr + src_row * inp_row_stride
    dst_base = out_ptr + idx_id * out_row_stride

    # This program's column chunk, masked against the row end.
    lanes = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < row_size

    vals = tl.load(src_base + lanes, mask=in_bounds, other=0.0)
    tl.store(dst_base + lanes, vals, mask=in_bounds)
@libentry()
@triton.jit
def index_select_dim1_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_rows,
    inp_row_stride,
    out_row_stride,
    num_indices,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """dim=1 index_select on a 2D input.

    Each program fills one BLOCK_M x BLOCK_N tile of (row, selected-column)
    pairs of the output.
    """
    tile_m = tle.program_id(axis=0)
    tile_n = tle.program_id(axis=1)

    # 2D lane coordinates: rows down, index slots across.
    row_ids = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    sel_ids = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    row_ok = row_ids < num_rows
    sel_ok = sel_ids < num_indices
    tile_ok = row_ok & sel_ok

    # Gather the source column number for every index slot in the tile.
    src_cols = tl.load(index_ptr + sel_ids, mask=sel_ok, other=0)

    # Gather from (row, src_col) and scatter to (row, sel_slot).
    src = inp_ptr + row_ids * inp_row_stride + src_cols
    dst = out_ptr + row_ids * out_row_stride + sel_ids

    vals = tl.load(src, mask=tile_ok, other=0.0)
    tl.store(dst, vals, mask=tile_ok)
125def _get_num_warps(total_elements):
126 """Get optimal num_warps based on workload size."""
127 if total_elements < 1024:
128 return 2
129 elif total_elements < 4096:
130 return 4
131 elif total_elements < 16384:
132 return 8
133 else:
134 return 16
def _index_select_2d_dim0(inp, out, index, index_len):
    """Launch the dim=0 fast-path kernels for a contiguous 2D input."""
    num_rows, row_size = inp.shape
    inp_row_stride = inp.stride(0)
    out_row_stride = out.stride(0)

    if row_size >= 16384:
        # Wide rows: 2D grid (index x column chunk) for more parallelism.
        BLOCK_SIZE = 1024
        grid = (index_len, triton.cdiv(row_size, BLOCK_SIZE))
        num_warps = _get_num_warps(BLOCK_SIZE)
        with torch_device_fn.device(inp.device):
            index_select_dim0_split_kernel[grid](
                inp,
                out,
                index,
                inp_row_stride,
                out_row_stride,
                row_size,
                index_len,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    else:
        # Narrow rows: each program copies one complete row.
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 2048)
        num_warps = _get_num_warps(BLOCK_SIZE)
        with torch_device_fn.device(inp.device):
            index_select_dim0_1d_kernel[(index_len,)](
                inp,
                out,
                index,
                inp_row_stride,
                out_row_stride,
                row_size,
                index_len,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    return out


def _index_select_2d_dim1(inp, out, index, index_len):
    """Launch the dim=1 fast-path kernel for a contiguous 2D input."""
    num_rows, num_cols = inp.shape
    inp_row_stride = inp.stride(0)
    out_row_stride = out.stride(0)

    BLOCK_M = min(triton.next_power_of_2(num_rows), 64)
    BLOCK_N = min(triton.next_power_of_2(index_len), 128)
    grid = (triton.cdiv(num_rows, BLOCK_M), triton.cdiv(index_len, BLOCK_N))
    num_warps = _get_num_warps(BLOCK_M * BLOCK_N)

    with torch_device_fn.device(inp.device):
        index_select_dim1_kernel[grid](
            inp,
            out,
            index,
            num_rows,
            inp_row_stride,
            out_row_stride,
            index_len,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            num_warps=num_warps,
        )
    return out


def index_select(inp, dim, index):
    """Select entries of ``inp`` along ``dim`` at positions given by ``index``.

    Mirrors ``torch.index_select``: the output has the same shape as ``inp``
    except that ``out.shape[dim] == index.numel()``. Contiguous 2D inputs with
    dim 0 or 1 take dedicated Triton fast paths; everything else falls back to
    the generic flag_gems implementation.

    Args:
        inp: input tensor.
        dim: dimension to select along; may be negative.
        index: 0-D or 1-D integer tensor of positions along ``dim``.

    Returns:
        A newly allocated tensor with the selected slices.
    """
    logger.debug("GEMS_MTHREADS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"

    # Promote a scalar index to a length-1 vector so a single code path works.
    if index.ndim == 0:
        index = index.unsqueeze(0)

    dim = dim % inp.ndim
    index_len = index.numel()

    out_shape = list(inp.shape)
    out_shape[dim] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy for an empty input or an empty index.
    if inp.numel() == 0 or index_len == 0:
        return out

    if inp.ndim == 2 and inp.is_contiguous() and dim in (0, 1):
        # The fast-path kernels dereference index_ptr with an implicit unit
        # stride; a non-contiguous index (e.g. a slice) would otherwise be
        # gathered incorrectly. No-op when the index is already contiguous.
        index = index.contiguous()
        if dim == 0:
            return _index_select_2d_dim0(inp, out, index, index_len)
        return _index_select_2d_dim1(inp, out, index, index_len)

    # Fall back to the generic implementation for all other cases.
    return default_index_select(inp, dim, index)