Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/hstack.py: 0%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import itertools
2import logging
3from typing import List, Tuple, Union
5import torch
6import triton
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
9from flag_gems.utils.tensor_wrapper import StridedBuffer
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Identity elementwise kernel. hstack uses it as a strided copy: the
# pointwise_dynamic launcher maps it over an input view and writes the
# result into a StridedBuffer view of the output, so each input tensor
# lands in its slot of the concatenated output buffer.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
def hstack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> torch.Tensor:
    """Concatenate tensors horizontally (torch.hstack semantics).

    1-D (and 0-D, which are first viewed as 1-D) inputs are concatenated
    along dim 0; higher-rank inputs along dim 1. Inputs are promoted to a
    common dtype via torch.promote_types.

    Args:
        tensors: non-empty sequence (list or tuple) of tensors with equal
            rank and matching sizes on every dim except the stack dim.

    Returns:
        A new tensor on the first input's device holding the concatenation.

    Raises:
        RuntimeError: on an empty input list, a rank mismatch, or a size
            mismatch on a non-stack dimension.
    """
    logger.debug("GEMS_TSINGMICRO HSTACK")

    if len(tensors) == 0:
        raise RuntimeError("hstack expected a non-empty TensorList")

    # Normalize into a fresh list:
    #  - accepts tuple inputs (the old in-place `tensors[i] = ...` raised
    #    TypeError for tuples),
    #  - never mutates the caller's list,
    #  - views every 0-dim tensor as 1-D up front (the old code rebound a
    #    loop-local for entries past index 0, so the copy phase still saw
    #    the 0-dim original while the shape bookkeeping assumed size 1).
    tensors = [t.view(1) if t.ndim == 0 else t for t in tensors]

    # Promote all inputs to a single common dtype.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    tensors = [t if t.dtype == dtype else t.to(dtype) for t in tensors]

    inp0_shape = tensors[0].shape
    out_shape = list(inp0_shape)
    inp_shapes = [inp0_shape]

    # hstack concatenates along dim 0 for 1-D inputs, dim 1 otherwise.
    dim = 0 if len(inp0_shape) == 1 else 1

    # Validate ranks and non-stack-dim sizes against the first tensor.
    for tensor_num, tensor in enumerate(tensors[1:]):
        if tensor.ndim != tensors[0].ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {tensors[0].ndim} and {tensor.ndim}"
            )

        inp_shape = tensor.shape
        inp_shapes.append(inp_shape)

        for i in range(len(inp_shape)):
            if i != dim and inp_shape[i] != inp0_shape[i]:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. \
Expected size {inp0_shape[i]} but got size {inp_shape[i]} \
for tensor number {tensor_num + 1} in the list."
                )

    out_shape[dim] = sum(s[dim] for s in inp_shapes)

    # Allocate the output in the promoted dtype on the first input's device.
    out0 = torch.empty(out_shape, dtype=dtype, device=tensors[0].device)
    out0_strides = out0.stride()
    # Element offset of each input's slot: running sum of the preceding
    # inputs' extents along the stack dim, scaled by the output stride.
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its strided window of the output buffer.
    for a, out0_offset in zip(tensors, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)

    return out0