Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cat.py: 0%
53 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
import itertools
import logging
from typing import List, Tuple, Union

import torch
import triton
from _kunlunxin.utils.codegen_config_utils import CodeGenConfig

from flag_gems.utils.tensor_wrapper import StridedBuffer

from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger nested under the package logger; the leading "." is stripped
# so relative-module names still resolve as children of "flag_gems".
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation configuration for the elementwise copy kernel below.
# NOTE(review): the positional arguments are presumably tile/grid/warp limits
# for the Kunlunxin backend — confirm against CodeGenConfig's signature.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
)
# Identity elementwise kernel: the pointwise_dynamic wrapper materializes the
# returned value into the output buffer passed at call time (``out0=``), so
# this effectively copies an input view into an output view.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x
def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate a sequence of tensors along dimension ``dim``.

    Mirrors ``torch.cat`` semantics: 1-D empty tensors (shape ``[0]``) are
    skipped; all remaining inputs must share rank and agree on every size
    except along ``dim``.

    Args:
        A: Non-empty sequence of tensors to concatenate.
        dim: Dimension along which to concatenate; may be negative.

    Returns:
        The concatenated tensor. When only one effective input remains it is
        returned as-is (no copy), matching the original behavior.

    Raises:
        RuntimeError: If ``A`` is empty, ranks differ, or non-``dim`` sizes
            mismatch.
        IndexError: If ``dim`` is out of range for the inputs.
    """
    logger.debug("GEMS CAT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        return A[0]

    # torch.cat ignores legacy 1-D empty tensors (exact shape [0]); drop them.
    device = A[0].device
    dtype = A[0].dtype
    A = [t for t in A if t.shape != torch.Size([0])]
    if len(A) == 0:
        return torch.tensor([], device=device, dtype=dtype)
    elif len(A) == 1:
        return A[0]

    # Validate dim explicitly — an assert would be stripped under ``-O``;
    # IndexError matches torch.cat's behavior for an out-of-range dim.
    if not (-A[0].ndim <= dim < A[0].ndim):
        raise IndexError(f"Invalid dim: {dim}")
    # Convert negative dim to positive
    dim = dim % A[0].ndim

    # Same rank check
    inp_shapes = [list(t.shape) for t in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # Same size check on every dimension except ``dim``.
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx == dim:
                continue
            elif length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Allocate the output; each input occupies a contiguous slab along ``dim``,
    # whose start offset (in elements) is the running sum of preceding extents
    # times the output stride along ``dim``.
    out_shape = list(inp0_shape)
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)
    out0_strides = out0.stride()
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its strided view of the output via the copy kernel,
    # instantiated for the input's rank.
    for a, out0_offset in zip(A, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)
    return out0