Coverage for src/flag_gems/runtime/backend/_metax/ops/repeat_interleave.py: 0%
117 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5from triton import language as tl
7from flag_gems.utils import triton_lang_extension as tle
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
9from flag_gems.utils.shape_utils import c_contiguous_stride
10from flag_gems.utils.tensor_wrapper import StridedBuffer
12logger = logging.getLogger("flag_gems." + __name__)
# Trivial identity kernel. Combined with the strided input/output views built
# by repeat_interleave_self_int, the pointwise_dynamic wrapper turns this into
# a broadcasting elementwise copy on the device.
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat each element of ``inp`` ``repeats`` times along ``dim``.

    Equivalent to ``torch.repeat_interleave(inp, repeats, dim)`` for an
    integer ``repeats``. When ``dim`` is None the input is flattened first.

    Args:
        inp: input tensor.
        repeats: non-negative repeat count applied to every element.
        dim: dimension to repeat along; None flattens the input.
        output_size: optional expected size of the result along ``dim``.

    Raises:
        IndexError: if ``dim`` is out of range.
        RuntimeError: if ``repeats`` is negative, or ``output_size`` does not
            match the computed output size.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats < 0:
        # Validate eagerly like PyTorch does, instead of failing later inside
        # torch.empty with a confusing negative-dimension error.
        raise RuntimeError("repeats must be non-negative")

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # View the input with a broadcast (stride-0) axis of length `repeats`
    # inserted right after `dim`, and view the output as the same shape laid
    # out contiguously; an elementwise copy between the two then materializes
    # the repeat.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output
@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    """Build repeat-interleave indices: program `pid` writes the value `pid`
    into out[cumsum[pid] - repeats[pid] : cumsum[pid]].

    `cumsum_ptr` holds the inclusive prefix sum of `repeats_ptr`, so
    `cumsum - repeats` is the exclusive (start) offset of this element's span.
    """
    pid = tle.program_id(0)
    # Guard in case the launch grid is larger than `size`; masked loads
    # fall back to 0 so out-of-range programs write nothing below.
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Fill this element's [0, repeats) output span in BLOCK_SIZE chunks.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
@triton.jit
def fused_repeat_and_index_select_kernel(
    inp, out, M, N, repeats, BLOCK_SIZE: tl.constexpr
):
    """Copy input row `pid` into output rows [cumsum[pid-1], cumsum[pid]).

    `repeats` holds the *inclusive prefix sum* of the per-row repeat counts
    (computed by the caller). `M` is the row length in elements; `N` is
    currently unused. Rows are addressed with flat row-major arithmetic
    (row * M + col), so both `inp` and `out` are assumed to be contiguous.
    """
    pid = tle.program_id(0)
    # Row 0 has no predecessor in the prefix sum; its span starts at 0.
    row_idx_mask = pid > 0
    start_row_idx = tl.load(repeats + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(repeats + pid)

    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    inp_row_offset = pid * M
    # Tile the row by columns; each loaded tile is stored once per target row.
    for m in range(0, tl.cdiv(M, BLOCK_SIZE)):
        cols_offsets = m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols_offsets < M
        cur_inp = tl.load(inp + inp_row_offset + cols_offsets, mask=mask, other=0.0)

        for cur_row_in_pid in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row_in_pid
            output_row_offsets = output_row_index * M + cols_offsets
            tl.store(out + output_row_offsets, cur_inp, mask=mask)
def repeat_interleave_tensor(repeats, *, output_size=None):
    """Return a 1D index tensor where value ``i`` appears ``repeats[i]`` times.

    Equivalent to ``torch.repeat_interleave(repeats)``.

    Args:
        repeats: 1D tensor of non-negative repeat counts.
        output_size: optional total output length; when supplied it is used
            directly, avoiding a device-to-host synchronization.

    Returns:
        1D tensor of dtype ``repeats.dtype`` on ``repeats.device``.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    if size == 0:
        # PyTorch returns an empty result for empty repeats; cumsum[-1]
        # below would raise IndexError on an empty tensor.
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    if output_size is not None:
        # Trust the caller-supplied size and skip the .item() sync.
        result_size = output_size
    else:
        result_size = cumsum[-1].item()
    assert result_size >= 0, "repeats can not be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)

    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out
def fused_repeat_and_index_select(inp, repeats, dim):
    """Repeat-interleave rows of a 2D tensor without materializing indices.

    Fast path used by repeat_interleave_self_tensor for ``inp.ndim == 2`` and
    ``dim == 0``: row ``i`` of ``inp`` is written ``repeats[i]`` times into
    the output, directly from a prefix sum of ``repeats``.

    Args:
        inp: 2D input tensor (only ``dim == 0`` is exercised by callers).
        repeats: 1D tensor of per-row repeat counts.
        dim: dimension being repeated.

    Returns:
        Tensor with shape ``inp.shape`` except ``repeats.sum()`` along ``dim``.
    """
    logger.debug("METAX GEMS FUSED_REPEAT_AND_INDEX_SELECT")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    # The kernel addresses rows with flat row-major arithmetic (row * M + col),
    # so a non-contiguous input would be read incorrectly.
    inp = inp.contiguous()

    repeats = repeats.cumsum(axis=0)
    index_len = repeats[-1].item()

    out_shape = list(inp.shape)
    out_shape[dim] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    N = inp.shape[dim]
    M = inp.numel() // N  # row length in elements

    grid = (inp.shape[dim],)  # one program per input row
    BLOCK_SIZE = min(triton.next_power_of_2(M), 4096)

    fused_repeat_and_index_select_kernel[grid](
        inp, out, M, N, repeats, BLOCK_SIZE=BLOCK_SIZE
    )
    return out
def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat elements of ``inp`` along ``dim`` per the tensor ``repeats``.

    Mirrors ``torch.repeat_interleave(inp, repeats, dim)`` for a tensor
    ``repeats``. A 0-dim or single-element ``repeats`` broadcasts to every
    slice and is delegated to the integer implementation; the 2D/dim-0 case
    takes a fused fast path, everything else goes through an explicit index
    tensor plus ``torch.index_select``.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    elif not (-inp.ndim <= dim < inp.ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -inp.ndim, inp.ndim - 1, dim
            )
        )

    # A scalar repeat count (0-dim, or a length-1 vector that broadcasts)
    # is exactly the integer case.
    is_scalar_repeat = repeats.ndim == 0 or (
        repeats.ndim == 1 and repeats.size(0) == 1
    )
    if is_scalar_repeat:
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    if repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    dim_sizes = list(inp.shape)
    if dim < 0:
        dim = dim + len(dim_sizes)

    if repeats.size(0) != dim_sizes[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got \
            repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, dim_sizes[dim]
            )
        )

    # Fused fast path for row-wise repetition of a matrix.
    if inp.ndim == 2 and dim == 0:
        return fused_repeat_and_index_select(inp, repeats, dim)

    gather_index = repeat_interleave_tensor(repeats)
    return torch.index_select(inp, dim, gather_index)