Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/cat.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import itertools
2import logging
3from typing import List, Tuple, Union
5import torch
6import triton
8from flag_gems.utils import pointwise_dynamic
9from flag_gems.utils.codegen_config_utils import CodeGenConfig
10from flag_gems.utils.tensor_wrapper import StridedBuffer
# Module logger, namespaced under the package-wide "flag_gems" logger.
# lstrip(".") guards against a leading dot in __name__ — TODO confirm when
# that can actually occur for this module.
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation configuration handed to pointwise_dynamic below.
# NOTE(review): positional args are presumably (max block size, tile sizes,
# num warps/stages, a boolean flag) — confirm against CodeGenConfig's
# signature before relying on this reading.
15config_ = CodeGenConfig(
16 1024,
17 (16, 1, 1),
18 32,
19 False,
    # Prefer 1-D tiling only on Triton major versions < 3;
    # triton.__version__[0] is the major-version digit as a string.
20 prefer_1d_tile=int(triton.__version__[0]) < 3,
21)
# Identity pointwise kernel. pointwise_dynamic generates the launch/indexing
# machinery around this body, so "copying" reduces to returning the input
# element; cat() below specializes it per rank via copy_func.instantiate(ndim)
# and routes the destination through the out0 keyword.
24@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
25@triton.jit
26def copy_func(x):
    # Pure element-wise identity: the generated kernel becomes a strided copy.
27 return x
def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along dimension ``dim``.

    Mirrors ``torch.cat`` semantics: all inputs must have the same rank and
    must agree on every dimension except ``dim``. The output is assembled by
    launching the Triton ``copy_func`` kernel once per input, each writing
    into a strided view of the result at the right offset.

    Args:
        A: Non-empty sequence of tensors to concatenate.
        dim: Dimension to concatenate along; may be negative.

    Returns:
        A new tensor with ``A[0]``'s dtype and device — except when ``A`` has
        exactly one element, in which case that element is returned as-is
        (no copy).

    Raises:
        RuntimeError: if ``A`` is empty, or the input shapes are incompatible.
        IndexError: if ``dim`` is out of range for the inputs' rank.
    """
    logger.debug("GEMS_TSINGMICRO CAT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        # Single input: nothing to copy, return it directly.
        return A[0]

    ndim = A[0].ndim
    # Validate `dim` with a real exception: `assert` is stripped under
    # `python -O` (which would let a bad dim fall through to `dim % ndim`),
    # and IndexError matches torch's behavior for out-of-range dims.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    # Convert negative dim to its positive equivalent.
    dim = dim % ndim

    # Same-rank check.
    inp_shapes = [list(t.shape) for t in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # Same-size check on every dimension except `dim`.
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx == dim:
                continue
            elif length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Output shape: the concat dim is the sum of the inputs' sizes there.
    out_shape = list(inp0_shape)
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)

    # Element offset into out0's storage where each input's slab begins:
    # a running sum of size-along-dim * stride-along-dim, starting at 0.
    out0_strides = out0.stride()
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its slab of the output through strided views.
    for a, out0_offset in zip(A, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)
    return out0