Coverage for src/flag_gems/runtime/backend/_ascend/ops/cat.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import itertools 

2import logging 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7 

8from flag_gems.utils import pointwise_dynamic 

9from flag_gems.utils.codegen_config_utils import CodeGenConfig 

10from flag_gems.utils.tensor_wrapper import StridedBuffer 

11 

# Per-module logger; resolves to "flag_gems.runtime._ascend.ops.cat".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


# Code-generation configuration passed to pointwise_dynamic for the copy
# kernel below.  The positional arguments presumably map to block size,
# a launch-grid tuple, an alignment/warp count, and a boolean flag --
# TODO(review): confirm the parameter names against CodeGenConfig.
config_ = CodeGenConfig(
    1024,
    (40, 1, 1),
    32,
    False,
    # triton.__version__[0] is the major-version digit; 1-D tiling is
    # preferred only on Triton releases older than 3.x.
    prefer_1d_tile=int(triton.__version__[0]) < 3,
)

22 

23 

# Identity pointwise kernel.  pointwise_dynamic generates rank-specialized
# copy kernels from this body; `cat` instantiates them to copy each input
# tensor into a strided view of the output buffer.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    # Returning the input unchanged turns the generated kernel into a
    # plain element-wise copy (the wrapper handles the output write).
    return x

28 

29 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate a sequence of tensors along dimension ``dim``.

    Allocates one contiguous output tensor and copies each input into the
    corresponding strided slice via the generated ``copy_func`` kernel.

    Args:
        A: Non-empty tuple or list of tensors.  All tensors must have the
           same number of dimensions and matching sizes in every dimension
           except ``dim``.
        dim: Dimension along which to concatenate; may be negative.

    Returns:
        A new tensor on the same device and with the same dtype as ``A[0]``.

    Raises:
        RuntimeError: if ``A`` is empty, ranks differ, or non-``dim``
            sizes mismatch.
        IndexError: if ``dim`` is out of range for the inputs' rank.
    """
    logger.debug("GEMS_ASCEND CAT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        # NOTE(review): fast path returns the input tensor itself (no copy),
        # whereas torch.cat always allocates -- confirm callers tolerate
        # the aliasing.
        return A[0]

    ndim = A[0].ndim
    # Validate dim with a real exception: a bare assert is stripped under
    # `python -O`, which would let an out-of-range dim slip through to the
    # modulo below.  IndexError matches torch.cat's own error class.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    # Convert negative dim to positive
    dim = dim % ndim

    # Same rank check
    inp_shapes = [list(_.shape) for _ in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # Same size check: every dimension except `dim` must agree with A[0].
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx == dim:
                continue
            elif length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Output shape equals A[0]'s shape with `dim` summed over all inputs.
    out_shape = list(inp0_shape)
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)
    out0_strides = out0.stride()
    # Element offset of each input's slice in the output: the running sum
    # of the preceding inputs' extents along `dim`, scaled by the output
    # stride of that dimension.
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    for a, out0_offset in zip(A, out0_offsets):
        # View the input as-is and the output as a same-shaped strided
        # window starting at this input's offset, then copy element-wise.
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)
    return out0