Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py: 0%
225 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
11from flag_gems.utils.shape_utils import c_contiguous_stride
12from flag_gems.utils.tensor_wrapper import StridedBuffer
# Module-level logger, namespaced under the mthreads backend ops package
# (resolves to "flag_gems.runtime.backend._mthreads.ops.repeat_interleave").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Identity elementwise op: pointwise_dynamic generates strided copy
    # kernels from it, used below to materialize a zero-stride repeat view.
    return x
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat each element of `inp` a fixed number of times along `dim`.

    Implements torch.repeat_interleave(inp, int_repeats, dim) by viewing the
    input with an extra zero-stride "repeat" axis and copying that view into
    a contiguous output.

    Args:
        inp: input tensor.
        repeats: non-negative int, number of copies of each element.
        dim: dimension to repeat along; ``None`` flattens the input first.
        output_size: optional expected output size along `dim`; validated
            against ``inp.size(dim) * repeats``.

    Returns:
        A new tensor with ``size(dim) == inp.size(dim) * repeats``.

    Raises:
        IndexError: if `dim` is out of range.
        RuntimeError: if `repeats` is negative or `output_size` mismatches.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # Fail early with torch's message; otherwise torch.empty below would
    # raise an opaque error about a negative dimension size.
    if repeats < 0:
        raise RuntimeError("repeats can not be negative")

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # Insert a length-`repeats` axis right after `dim`: stride 0 on the input
    # side (each element is read `repeats` times), contiguous strides on the
    # output side. A plain copy of this view realizes the repetition.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output
@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    """Write the value `pid` into its repeated output slice.

    One program per element of `repeats`. Program `pid` fills
    out[cumsum[pid] - repeats[pid] : cumsum[pid]] with `pid`, so the output
    is the index vector produced by repeat_interleave(repeats).
    `cumsum_ptr` holds the inclusive prefix sum of `repeats`.
    """
    pid = tle.program_id(0)
    # Guard: the grid may contain more programs than elements.
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    # Inclusive cumsum minus this element's own count = start offset.
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Fill this program's slice BLOCK_SIZE elements at a time; `repeats`
    # is a runtime scalar, so this is a data-dependent loop.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
def repeat_interleave_tensor(repeats, *, output_size=None):
    """Expand a 1D `repeats` tensor into its index vector.

    Returns a tensor where index ``i`` appears ``repeats[i]`` times, i.e.
    torch.repeat_interleave(repeats). The result has the same dtype and
    device as `repeats`.

    Args:
        repeats: 1D tensor of non-negative repeat counts.
        output_size: accepted for API compatibility; not used here.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    # Empty repeats => empty result; without this guard cumsum[-1] below
    # raises IndexError on a zero-length tensor.
    if size == 0:
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats can not be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)

    # One program per input element; each fills its own output slice.
    grid = (size,)
    BLOCK_SIZE = 32
    with torch_device_fn.device(repeats.device):
        repeat_interleave_tensor_kernel[grid](
            repeats,
            cumsum,
            out,
            size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=1,
        )
    return out
@libentry()
@triton.jit
def fused_repeat_interleave_dim0_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for repeat_interleave with dim=0 (input-centric).

    Each program handles one input row and copies it to all of its repeated
    output positions. `cumsum_ptr` holds the inclusive prefix sum of the
    per-row repeat counts, so input row `pid` maps to output rows
    [cumsum[pid-1], cumsum[pid]). Both buffers are contiguous row-major
    with `row_size` columns.
    """
    pid = tle.program_id(0)

    if pid >= num_input_rows:
        return

    # Get output row range for this input row; row 0 starts at 0 via the
    # masked load (there is no cumsum[-1] entry).
    row_idx_mask = pid > 0
    start_row_idx = tl.load(cumsum_ptr + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(cumsum_ptr + pid)

    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    # Calculate input row offset
    inp_row_offset = pid * row_size

    # Process columns in blocks; each loaded block is reused for every
    # repeated output row, so the input is read only once.
    for col_block in range(0, tl.cdiv(row_size, BLOCK_SIZE)):
        col_offsets = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )

        # Store to each output row
        for cur_row in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row
            output_row_offsets = output_row_index * row_size + col_offsets
            tl.store(out_ptr + output_row_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_output_centric_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel for repeat_interleave with dim=0.

    Uses a 2D grid: (num_output_rows, num_col_chunks); each program copies
    one BLOCK_SIZE chunk of one output row. The source input row is found
    by binary search over the inclusive cumsum of the repeat counts.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Binary search to find input row index:
    # the smallest i such that cumsum[i] > out_row_idx owns this output row.
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    inp_row_idx = low

    # Column offsets covered by this chunk of the row.
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load one chunk of the source row (contiguous row-major layout).
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store it at the same chunk position of the output row.
    out_row_offset = out_row_idx * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_1d_bsearch_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """1D output-centric kernel with binary search.

    Each program handles one complete output row, sweeping its columns in
    BLOCK_SIZE chunks. Better for large row sizes: fewer programs, and the
    per-program binary search is amortized over a long row.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Binary search over the inclusive cumsum of repeats: the smallest i
    # with cumsum[i] > out_row_idx is this output row's source input row.
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    inp_row_idx = low

    # Base offsets of the source and destination rows (contiguous layout).
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx * row_size

    # Copy the whole row, one BLOCK_SIZE chunk at a time.
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_with_indices_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric copy kernel driven by a precomputed index mapping.

    2D grid: (num_output_rows, num_col_chunks). Program (r, c) copies the
    c-th BLOCK_SIZE chunk of input row index_ptr[r] into output row r.
    """
    row_out = tle.program_id(0)
    chunk = tle.program_id(1)

    if row_out >= num_output_rows:
        return

    # Which input row feeds this output row (computed on the host).
    row_in = tl.load(index_ptr + row_out)

    # Column span covered by this program's chunk.
    cols = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < row_size

    # Move one chunk of the source row into the destination row.
    src = tl.load(inp_ptr + row_in * row_size + cols, mask=in_bounds, other=0.0)
    tl.store(out_ptr + row_out * row_size + cols, src, mask=in_bounds)
@libentry()
@triton.jit
def fused_repeat_interleave_large_row_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy kernel tuned for large rows.

    One program per output row: it looks up its source row in the
    precomputed index map, then streams the entire row through in
    BLOCK_SIZE chunks.
    """
    row_out = tle.program_id(0)

    if row_out >= num_output_rows:
        return

    # Source row index was precomputed on the host.
    row_in = tl.load(index_ptr + row_out)

    # Base offsets of source and destination rows (contiguous layout).
    src_base = row_in * row_size
    dst_base = row_out * row_size

    # Sweep the whole row in fixed-size chunks.
    for chunk_start in range(0, row_size, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        live = cols < row_size

        vals = tl.load(inp_ptr + src_base + cols, mask=live, other=0.0)
        tl.store(out_ptr + dst_base + cols, vals, mask=live)
def fused_repeat_interleave_dim0(inp, repeats, dim):
    """Fused repeat_interleave along dim=0.

    Rows along `dim` are repeated per-row according to `repeats`; the
    remaining dimensions are flattened into `row_size` columns for the
    kernels. One of three kernels is chosen by problem shape.

    Args:
        inp: input tensor; made contiguous before launch.
        repeats: 1D tensor of per-row repeat counts; len == inp.shape[dim].
        dim: dimension being repeated (callers pass 0).

    Returns:
        A new tensor whose size along `dim` equals repeats.sum().
    """
    logger.debug("GEMS_MTHREADS FUSED_REPEAT_INTERLEAVE_DIM0")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    # Zero input rows: output is empty along `dim`. Without this guard,
    # cumsum[-1] below raises IndexError and the row_size division by
    # num_input_rows divides by zero.
    if repeats.numel() == 0:
        out_shape = list(inp.shape)
        out_shape[dim] = 0
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Inclusive prefix sum: input row i owns output rows
    # [cumsum[i-1], cumsum[i]).
    cumsum = repeats.cumsum(axis=0)
    total_output_rows = cumsum[-1].item()

    if total_output_rows == 0:
        out_shape = list(inp.shape)
        out_shape[dim] = 0
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Setup output tensor
    out_shape = list(inp.shape)
    out_shape[dim] = total_output_rows
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Flatten non-dim dimensions for easier indexing
    num_input_rows = inp.shape[dim]
    row_size = inp.numel() // num_input_rows

    # Degenerate rows (another dimension is 0): nothing to copy, and the
    # BLOCK_SIZE computations below would see row_size == 0.
    if row_size == 0:
        return out

    # Make input contiguous for efficient access
    inp_contig = inp.contiguous()

    # Strategy selection:
    # 1. Small tensors: input-centric kernel
    # 2. Medium row sizes: output-centric 2D grid with binary search
    # 3. Large row sizes: output-centric 1D grid with binary search
    if row_size < 512 and total_output_rows < 512:
        # Small tensor: one program per input row; each loaded block is
        # reused across all of that row's repeats.
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 4096)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (num_input_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_dim0_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    elif row_size >= 16384:
        # Large row size: 1D grid, one program per output row. Fewer
        # programs amortize the binary-search cost over a long row.
        BLOCK_SIZE = 2048
        num_warps = 16

        grid = (total_output_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_1d_bsearch_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    else:
        # Medium row size: 2D grid (output row, column chunk) with binary
        # search per program.
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 1024)
        num_col_chunks = triton.cdiv(row_size, BLOCK_SIZE)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (total_output_rows, num_col_chunks)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_output_centric_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

    return out
def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat elements of `inp` along `dim` with per-element counts.

    Implements torch.repeat_interleave(inp, repeats_tensor, dim). A 0-dim
    or single-element `repeats` degenerates to the scalar variant; dim=0
    uses the fused kernels; other dims fall back to building an index
    vector and torch.index_select.

    Raises:
        IndexError: if `dim` is out of range.
        RuntimeError: if `repeats` has more than one dimension, or its
            length does not match inp.size(dim).
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        # A scalar count applies uniformly; delegate to the int variant.
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        # Adjacent string literals keep the message on one clean line; the
        # previous backslash continuation inside the literal embedded a run
        # of indentation spaces into the user-facing message.
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Use fused kernel for dim=0
    if dim == 0:
        return fused_repeat_interleave_dim0(inp, repeats, dim)

    # Fallback for other dims: expand `repeats` into an index vector and
    # gather along `dim`.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res