Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/stack.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import itertools
2import logging
3from typing import List, Tuple, Union
5import torch
6import triton
8from flag_gems.utils.tensor_wrapper import StridedBuffer
10from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module so that records can be filtered per op.
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))
# Identity scalar function: returns its single input unchanged.
# `pointwise_dynamic` turns it into an elementwise kernel; `stack` calls
# `copy_func.instantiate(ndim)(in_view, out0=out_view)` to copy each input
# tensor into its strided slice of the preallocated output.
# `is_tensor=[True]` marks the only argument as a tensor. The
# `(0, "DEFAULT")` promotion method presumably derives the result dtype from
# argument 0 using default promotion — NOTE(review): confirm against the
# pointwise_dynamic documentation.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack equally-shaped tensors along a newly inserted dimension ``dim``.

    Mirrors ``torch.stack``. The result has one more dimension than the
    inputs, with ``len(tensors)`` entries along ``dim``. Each input is copied
    into its slice of a freshly allocated output by the ``copy_func``
    pointwise kernel.

    Args:
        tensors: Non-empty sequence of tensors that all share the same shape.
        dim: Position of the inserted dimension. Valid range is
            ``[-(ndim + 1), ndim]``, where ``ndim`` is the rank of the inputs.

    Returns:
        A new tensor of shape ``shape[:dim] + [len(tensors)] + shape[dim:]``
        with the dtype and device of ``tensors[0]``.

    Raises:
        RuntimeError: If ``tensors`` is empty or the input shapes differ.
        IndexError: If ``dim`` is out of range.
    """
    logger.debug("GEMS STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    # Validate `dim` once, against the first tensor, before comparing shapes
    # (the same order torch.stack uses). Doing it up front, rather than per
    # tensor in the shape loop, also covers single-tensor inputs. The output
    # gains one dimension, hence the range [-(ndim + 1), ndim].
    ndim = tensors[0].dim()
    if dim < -ndim - 1 or dim > ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -ndim - 1, ndim, dim
            )
        )

    inp0_shape = list(tensors[0].shape)
    for i, tensor in enumerate(tensors[1:], start=1):
        shape = list(tensor.shape)
        if shape != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {shape} at entry {i}"
            )

    if dim < 0:
        dim = dim + ndim + 1

    # Each input is viewed with a size-1 axis at `dim` so that it lines up
    # elementwise with its own slice of the output.
    in0_shape = inp0_shape[:dim] + [1] + inp0_shape[dim:]
    out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
    out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)
    out0_strides = out0.stride()

    # Element offset of the k-th slice along `dim` is k * stride(dim):
    # 0, s, 2s, ..., (len(tensors) - 1) * s.
    out0_offsets = itertools.accumulate(
        itertools.repeat(out0_strides[dim], len(tensors) - 1), initial=0
    )

    for tensor, out0_offset in zip(tensors, out0_offsets):
        src = tensor.reshape(in0_shape)
        in_view = StridedBuffer(src, in0_shape, src.stride())
        out_view = StridedBuffer(out0, in0_shape, out0_strides, offset=out0_offset)
        copy_func.instantiate(src.ndim)(in_view, out0=out_view)

    return out0