Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/repeat_interleave.py: 0%
83 statements
import logging

import torch
import triton
from triton import language as tl

from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.shape_utils import c_contiguous_stride
from flag_gems.utils.tensor_wrapper import StridedBuffer

from ..utils.pointwise_dynamic import pointwise_dynamic

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Identity pointwise op; pointwise_dynamic generates the strided copy loop.
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
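
# A plain-PyTorch sketch of the stride-0 trick used below (illustrative only,
# not the code path taken here): viewing a tensor through an extra axis with
# stride 0 makes each element appear `repeats` times when read contiguously.
#
#     x = torch.tensor([1, 2, 3])
#     v = x.as_strided((3, 2), (1, 0))  # extra axis, stride 0
#     v.reshape(-1)                     # tensor([1, 1, 2, 2, 3, 3])
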
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # View the input with an extra stride-0 axis of length `repeats` right
    # after `dim`, so the copy kernel reads each element `repeats` times
    # while writing to a contiguous output of the same expanded shape.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output
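
# A minimal usage sketch (hedged: intended to mirror
# torch.repeat_interleave(x, int); the device string depends on the backend):
#
#     x = torch.tensor([[1, 2], [3, 4]], device="cuda")
#     repeat_interleave_self_int(x, 2, dim=0)
#     # -> [[1, 2], [1, 2], [3, 4], [3, 4]]
#     repeat_interleave_self_int(x, 2)  # dim=None flattens first
#     # -> [1, 1, 2, 2, 3, 3, 4, 4]
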
@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # One program per input position: program `pid` fills its output slice
    # [cumsum[pid] - repeats[pid], cumsum[pid]) with the value `pid`.
    pid = tle.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats cannot be negative")

    out_ptr += out_offset
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
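
# Worked example of the kernel's index math: for repeats = [2, 0, 3],
# cumsum = [2, 2, 5] and each program pid writes the value pid into
# out[cumsum[pid] - repeats[pid] : cumsum[pid]]:
#   pid 0 -> out[0:2] = 0
#   pid 1 -> out[2:2]     (empty, repeats[1] == 0)
#   pid 2 -> out[2:5] = 2
# giving out = [0, 0, 2, 2, 2].
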
def repeat_interleave_tensor(repeats, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accepts a 1D vector of repeats"

    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats cannot be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    size = repeats.size(0)

    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out
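
# A minimal usage sketch (hedged: produces the same index vector as
# torch.repeat_interleave(repeats); the device string depends on the backend):
#
#     r = torch.tensor([1, 3, 2], device="cuda")
#     repeat_interleave_tensor(r)  # -> tensor([0, 1, 1, 1, 2, 2])
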
def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A 0-dim or single-element repeats tensor degenerates to the int case.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be a 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Expand repeats into gather indices, then gather along dim.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res
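
# A minimal usage sketch (hedged: per-position repeat counts along dim 0;
# the device string depends on the backend):
#
#     x = torch.tensor([[1, 2], [3, 4]], device="cuda")
#     r = torch.tensor([1, 2], device="cuda")
#     repeat_interleave_self_tensor(x, r, dim=0)
#     # -> [[1, 2], [3, 4], [3, 4]]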