Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cat.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import itertools 

2import logging 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

8 

9from flag_gems.utils.tensor_wrapper import StridedBuffer 

10 

11from ..utils.pointwise_dynamic import pointwise_dynamic 

12 

# Module logger: child of the "flag_gems" root logger; __name__ is stripped of
# any leading dots so the child name nests cleanly under "flag_gems".
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Code-generation configuration for the pointwise copy kernel below.
# NOTE(review): the first four arguments are positional and their parameter
# names are not visible from this file — confirm their meaning against
# _kunlunxin.utils.codegen_config_utils.CodeGenConfig before changing them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,  # favor 1-D tiling for the flat element-copy pattern
)

21 

22 

# Identity kernel used to implement cat() as a series of strided copies.
# pointwise_dynamic generates the launch/wrapper code; the wrapper receives the
# destination as keyword `out0` (see the call site in cat()), so this body only
# needs to return its single tensor input unchanged.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x

27 

28 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate a sequence of tensors along an existing dimension.

    Mirrors ``torch.cat``: every tensor must have the same rank and matching
    sizes in every dimension except ``dim``.  Empty 1-D tensors (shape
    ``torch.Size([0])``) are dropped before concatenation, matching
    ``torch.cat``'s legacy tolerance for them.

    Args:
        A: Non-empty sequence of tensors to concatenate.
        dim: Dimension along which to concatenate; may be negative.

    Returns:
        A tensor whose size along ``dim`` is the sum of the inputs' sizes.
        If only one tensor remains after filtering, that tensor itself is
        returned without copying.

    Raises:
        RuntimeError: If ``A`` is empty, ranks differ, or a non-``dim`` size
            mismatches.
        IndexError: If ``dim`` is out of range for the inputs' rank.
    """
    logger.debug("GEMS CAT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        return A[0]

    # Drop empty 1-D tensors (torch.Size([0])); capture device/dtype first in
    # case everything gets filtered out.
    device = A[0].device
    dtype = A[0].dtype
    A = [t for t in A if t.shape != torch.Size([0])]
    if len(A) == 0:
        return torch.tensor([], device=device, dtype=dtype)
    elif len(A) == 1:
        return A[0]

    # Validate `dim` with a real exception: an `assert` is stripped under
    # `python -O`, and torch.cat raises IndexError for an out-of-range dim.
    ndim = A[0].ndim
    if not (-ndim <= dim < ndim):
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    # Convert negative dim to positive
    dim = dim % ndim

    # Same rank check
    inp_shapes = [list(t.shape) for t in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # Same size check: every dimension except `dim` must agree with tensor 0.
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx == dim:
                continue
            elif length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Allocate the output; its size along `dim` is the sum of the inputs'.
    out_shape = list(inp0_shape)
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)
    out0_strides = out0.stride()
    # Element offset into out0 where each input's slice starts: running sum of
    # (input size along dim) * (output stride along dim), beginning at 0.
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its strided window of the output buffer via the
    # generated pointwise identity kernel.
    for a, out0_offset in zip(A, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)
    return out0