Coverage for src/flag_gems/runtime/backend/_cambricon/ops/repeat_interleave.py: 0%
111 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5from triton import language as tl
7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
8from flag_gems.utils.shape_utils import c_contiguous_stride
9from flag_gems.utils.tensor_wrapper import StridedBuffer
# Module logger, namespaced as a child of the shared "flag_gems" logger so
# package-wide logging configuration applies to this backend op as well.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Elementwise identity kernel. It is instantiated for a concrete rank via
# ``copy_func.instantiate(ndim)`` and used to copy between strided views,
# which is how the interleaved repetition is materialized below.
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
def repeat_interleave_self_int_forward(inp, repeats, dim=None, *, output_size=None):
    """Repeat every slice of ``inp`` along ``dim`` ``repeats`` times, interleaved.

    ``dim`` is expected to already be a normalized (non-negative) axis index
    supplied by the caller.  Raises ``RuntimeError`` when ``output_size`` is
    given and disagrees with the computed size along ``dim``.
    """
    src_shape = list(inp.shape)
    src_stride = list(inp.stride())

    result_shape = list(inp.shape)
    result_shape[dim] *= repeats
    if output_size is not None and output_size != result_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                result_shape[dim], output_size
            )
        )

    result = torch.empty(result_shape, dtype=inp.dtype, device=inp.device)
    if repeats == 0:
        # Nothing to copy; the result is an empty tensor along ``dim``.
        return result

    # Insert a broadcast (stride-0) axis of length ``repeats`` right after
    # ``dim`` in the input view, and a real contiguous axis at the same spot
    # in the output view; copying between the two realizes the interleaving.
    view_shape = src_shape[: dim + 1] + [repeats] + src_shape[dim + 1 :]
    src_view = StridedBuffer(
        inp, view_shape, src_stride[: dim + 1] + [0] + src_stride[dim + 1 :]
    )
    dst_view = StridedBuffer(result, view_shape, c_contiguous_stride(view_shape))
    copy_func.instantiate(len(view_shape))(src_view, out0=dst_view)
    return result
class RepeatInterleaveSelfIntFn(torch.autograd.Function):
    """autograd.Function for ``repeat_interleave(inp, repeats: int, dim)``."""

    @staticmethod
    def forward(ctx, inp, repeats, dim, output_size):
        """Repeat each element of ``inp`` ``repeats`` times along ``dim``.

        When ``dim`` is None the input is flattened first, matching
        ``torch.repeat_interleave`` semantics.  Raises ``IndexError`` for an
        out-of-range ``dim``.
        """
        logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_INT FORWARD")
        # Save the *original* (pre-flatten) shape and the *unnormalized* dim;
        # backward reconstructs the gradient against this shape.
        ctx.inp_shape = inp.shape
        ctx.dim = dim
        if dim is None:
            inp = inp.flatten()
            dim = 0
        else:
            if (dim < -inp.ndim) or (dim >= inp.ndim):
                raise IndexError(
                    "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                        -inp.ndim, inp.ndim - 1, dim
                    )
                )
            inp_shape = list(inp.shape)
            if dim < 0:
                dim = dim + len(inp_shape)
        ctx.repeats = repeats
        ctx.output_size = output_size

        out = repeat_interleave_self_int_forward(
            inp, repeats, dim=dim, output_size=output_size
        )
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Sum the incoming gradient over each group of ``repeats`` copies."""
        logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_INT BACKWARD")
        dim = ctx.dim
        k = ctx.repeats
        shape = ctx.inp_shape
        new_shape = list(shape)
        # Use reshape() instead of view(): autograd does not guarantee that
        # grad_out is contiguous, and view() raises on non-contiguous input.
        if ctx.dim is None:
            # Forward flattened the input, so the k copies of each element are
            # consecutive; expose them as a trailing axis and reduce over it.
            new_shape.insert(len(shape), k)
            grad_view = grad_out.reshape(*new_shape)
            grad_x = grad_view.sum(dim=len(shape))
        else:
            if dim < 0:
                dim = dim + len(shape)
            # The k copies sit right after ``dim`` in the output layout.
            new_shape.insert(dim + 1, k)
            grad_view = grad_out.reshape(*new_shape)
            grad_x = grad_view.sum(dim=dim + 1)
        # One gradient per forward argument: (inp, repeats, dim, output_size).
        return grad_x, None, None, None
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Public entry point for ``repeat_interleave`` with an integer ``repeats``.

    Routed through an ``autograd.Function`` so that backward is supported.
    """
    return RepeatInterleaveSelfIntFn.apply(inp, repeats, dim, output_size)
@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    """Fill ``out`` with index ``pid`` repeated ``repeats[pid]`` times.

    One program per input element: program ``pid`` writes the value ``pid``
    into ``out[cumsum[pid] - repeats[pid] : cumsum[pid]]``, looping over the
    run in chunks of BLOCK_SIZE.
    """
    pid = tl.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    # cumsum is inclusive; subtracting this element's count yields the
    # exclusive prefix sum, i.e. the start offset of this element's run.
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
def repeat_interleave_tensor(repeats, *, output_size=None):
    """Expand a 1D ``repeats`` tensor into an index vector.

    Returns a 1D tensor in which the value ``i`` appears ``repeats[i]`` times,
    in order — i.e. ``torch.repeat_interleave(repeats)``.  ``output_size`` is
    accepted for API compatibility but not used here.
    """
    logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    if size == 0:
        # Guard the empty case: ``cumsum[-1]`` below would raise IndexError,
        # while an empty repeats vector simply expands to an empty result.
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()
    # Negative per-element repeats are caught on-device by the kernel's
    # device_assert; this only rejects an overall-negative total.
    assert result_size >= 0, "repeats can not be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)

    # One program per input element; each fills its own run of the output.
    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out
def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """``repeat_interleave`` with a tensor ``repeats``.

    Repeats each slice of ``inp`` along ``dim`` according to the matching
    entry of ``repeats``.  When ``dim`` is None the input is flattened first.
    Raises ``IndexError`` for an out-of-range ``dim`` and ``RuntimeError``
    when ``repeats`` has the wrong rank or length.
    """
    logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A 0-d or single-element repeats tensor behaves like a scalar repeat;
    # delegate to the integer fast path (which also supports autograd).
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        # Fixed: the previous message used a backslash line-continuation
        # inside the literal, embedding a long run of spaces in the text.
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Build gather indices [0 x repeats[0], 1 x repeats[1], ...] and select
    # along ``dim``; index_select produces the interleaved result directly.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res