Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/stack.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import itertools 

2import logging 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7 

8from flag_gems.utils.tensor_wrapper import StridedBuffer 

9 

10from ..utils.pointwise_dynamic import pointwise_dynamic 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

# Identity elementwise kernel. The `pointwise_dynamic` wrapper turns this into
# a launcher that can be instantiated per-rank (see `copy_func.instantiate`
# below) and that writes the result into a caller-supplied output buffer.
# NOTE(review): output allocation/striding is presumably handled by the
# wrapper, not by this kernel body — confirm against pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # The kernel itself is a pure copy: it returns its input unchanged.
    return x

19 

20 

def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack a sequence of same-shaped tensors along a new dimension.

    Mirrors ``torch.stack``: a fresh dimension of size ``len(tensors)`` is
    inserted at position ``dim`` and each input is copied into its slice of
    the output via the ``copy_func`` pointwise kernel.

    Args:
        tensors: non-empty sequence of tensors, all with identical shape.
        dim: position of the inserted dimension; may be negative. Valid
            range is ``[-ndim - 1, ndim]`` where ``ndim`` is the rank of
            the inputs.

    Returns:
        A new tensor of shape ``shape[:dim] + [len(tensors)] + shape[dim:]``
        with the dtype/device of ``tensors[0]``.

    Raises:
        RuntimeError: if ``tensors`` is empty, or the shapes mismatch.
        IndexError: if ``dim`` is out of range.
    """
    logger.debug("GEMS STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    inp_shapes = [list(_.shape) for _ in tensors]
    inp0_shape = inp_shapes[0]

    # Validate `dim` once against the first tensor. The previous version ran
    # this check inside the loop over tensors[1:], so a call with a single
    # tensor skipped validation entirely and an out-of-range negative `dim`
    # silently produced a wrong output shape instead of raising. Since all
    # tensors must share a shape anyway, checking tensor 0 is equivalent for
    # the multi-tensor case.
    ndim = tensors[0].dim()
    if (dim < -ndim - 1) or (dim > ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -ndim - 1, ndim, dim
            )
        )

    for i, s in enumerate(inp_shapes[1:]):
        if s != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}"
            )

    # Normalize negative dim; +1 because the new dimension extends the rank.
    if dim < 0:
        dim = dim + len(inp0_shape) + 1

    # Each input is viewed with a singleton dimension inserted at `dim`; the
    # output carries len(tensors) there instead.
    in0_shape = inp0_shape[:dim] + [1] + inp0_shape[dim:]
    out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
    out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)
    out0_strides = out0.stride()
    # Element offset of each input's slice along `dim` within the output:
    # running sum 0, s, 2s, ... where s = stride of the stacked dimension.
    out0_offsets = list(
        itertools.accumulate([out0_strides[dim] for _ in inp_shapes[:-1]], initial=0)
    )

    # Copy every input into its strided window of the output buffer.
    for a, out0_offset in zip(tensors, out0_offsets):
        a = a.reshape(in0_shape)
        in_view = StridedBuffer(a, in0_shape, a.stride())
        out_view = StridedBuffer(out0, in0_shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)

    return out0