Coverage for src/flag_gems/runtime/backend/_ascend/ops/repeat_interleave.py: 0%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils.codegen_config_utils import CodeGenConfig
7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
8from flag_gems.utils.shape_utils import c_contiguous_stride
9from flag_gems.utils.tensor_wrapper import StridedBuffer
# Module logger scoped under the flag_gems Ascend backend namespace,
# e.g. "flag_gems.runtime._ascend.ops.repeat_interleave".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

# Code-generation configuration passed to pointwise_dynamic for this backend.
# NOTE(review): the first four arguments are positional; their meaning cannot
# be confirmed from this file alone — verify against CodeGenConfig's signature
# before renaming or reordering them.
config_ = CodeGenConfig(
    2048,
    (48, 1, 1),
    32,
    False,
    # Triton major version < 3 gets 1D tiling here.
    prefer_1d_tile=int(triton.__version__[0]) < 3,
)
# Identity pointwise kernel. `pointwise_dynamic` makes it instantiable per
# tensor rank via `.instantiate(ndim)`; it is used below to copy a
# (possibly stride-0 broadcast) input view into a contiguous output buffer
# through the `out0=` keyword. Kept undocumented inside the body on purpose:
# the decorators code-generate from this function's source.
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat each element of ``inp`` ``repeats`` times along ``dim``.

    Integer-``repeats`` counterpart of ``torch.repeat_interleave``. When
    ``dim`` is None the input is flattened first and repetition happens
    along the resulting flat dimension.

    Args:
        inp: input tensor.
        repeats: non-negative int, repetitions per element.
        dim: dimension to repeat along; None flattens the input first.
        output_size: optional expected output size along ``dim``; a
            RuntimeError is raised if it disagrees with the computed size.

    Returns:
        A new contiguous tensor shaped like ``inp`` except that the size
        along ``dim`` is multiplied by ``repeats``.

    Raises:
        IndexError: if ``dim`` is out of range for ``inp``.
        RuntimeError: if ``repeats`` is negative or ``output_size``
            mismatches the computed output size.
    """
    logger.debug("GEMS_ASCEND REPEAT_INTERLEAVE_SELF_INT")
    if repeats < 0:
        # Fail fast with a clear message (matching upstream torch semantics)
        # instead of letting torch.empty choke on a negative dimension size.
        raise RuntimeError("repeats must be non-negative")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    elif (dim < -inp.ndim) or (dim >= inp.ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -inp.ndim, inp.ndim - 1, dim
            )
        )

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy: the output already has a zero-sized `dim`.
    if repeats == 0:
        return output

    # View the input with a broadcast (stride-0) axis of length `repeats`
    # inserted right after `dim`; copying that view into a contiguous buffer
    # of the same logical shape materializes the repetition.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output