Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/hstack.py: 0%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import itertools 

2import logging 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7 

8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

9from flag_gems.utils.tensor_wrapper import StridedBuffer 

10 

# Child logger under the package-wide "flag_gems" root logger; lstrip(".")
# guards against a leading dot if __name__ ever arrives in relative form.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

# Identity elementwise kernel: wrapped by pointwise_dynamic so that writing
# the input straight through to the designated output buffer performs a
# (possibly strided) device-side copy.
# NOTE(review): @pointwise_dynamic / @triton.jit inspect and compile this
# function's source, so its exact form (single tensor arg, bare `return x`)
# is load-bearing — do not restyle.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x

18 

19 

def hstack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> torch.Tensor:
    """Stack tensors horizontally (column-wise), mirroring ``torch.hstack``.

    1-D inputs are concatenated along dim 0; higher-rank inputs along dim 1.
    0-d inputs are treated as 1-element 1-D tensors. All inputs are promoted
    to a common dtype (torch type-promotion rules) before copying.

    Args:
        tensors: non-empty list or tuple of tensors whose shapes match in
            every dimension except the concatenation dimension.

    Returns:
        A newly allocated tensor holding the concatenated inputs.

    Raises:
        RuntimeError: if ``tensors`` is empty, if ranks differ, or if a
            non-concatenation dimension mismatches.
    """
    logger.debug("GEMS_TSINGMICRO HSTACK")

    if len(tensors) == 0:
        raise RuntimeError("hstack expected a non-empty TensorList")

    # Normalize into a fresh list instead of mutating the caller's sequence:
    # the annotated type allows tuples, and the previous in-place assignments
    # (tensors[0] = ..., tensors[i] = ...) raised TypeError on tuple input.
    # Promoting 0-d tensors to 1-D here (persistently, for every input, not
    # just the first) also fixes the old loop-local `tensor = tensor.view(1)`
    # whose result was discarded before the copy phase.
    tensors = [t.view(1) if t.ndim == 0 else t for t in tensors]

    inp0_shape = tensors[0].shape
    out_shape = list(inp0_shape)
    inp_shapes = [inp0_shape]

    # Fold all input dtypes into a single promoted dtype, then convert any
    # input that is not already in that dtype.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    tensors = [t if t.dtype == dtype else t.to(dtype) for t in tensors]

    # hstack semantics: concatenate along dim 0 for 1-D inputs, dim 1 otherwise.
    dim = 0 if len(inp0_shape) == 1 else 1

    # Validate ranks and non-concat dimensions against the first tensor.
    for tensor_num, tensor in enumerate(tensors[1:]):
        if tensor.ndim != tensors[0].ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {tensors[0].ndim} and {tensor.ndim}"
            )

        inp_shape = tensor.shape
        inp_shapes.append(inp_shape)

        for i in range(len(inp_shape)):
            if i != dim and inp_shape[i] != inp0_shape[i]:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. \
                    Expected size {inp0_shape[i]} but got size {inp_shape[i]} \
                    for tensor number {tensor_num + 1} in the list."
                )

    out_shape[dim] = sum(s[dim] for s in inp_shapes)

    # Allocate the output directly in the promoted dtype (clearer than
    # relying on tensors[0] having been converted above).
    out0 = torch.empty(out_shape, dtype=dtype, device=tensors[0].device)
    out0_strides = out0.stride()
    # Running element offset of each input's slab along the concat dimension.
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its strided window of the output buffer.
    for a, out0_offset in zip(tensors, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)

    return out0