Coverage for src/flag_gems/experimental_ops/unsqueeze_copy.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _unsqueeze_copy_kernel(
    src_ptr,  # pointer to input tensor data
    dst_ptr,  # pointer to output tensor data
    sizes_ptr,  # pointer to int64 sizes of src tensor (NDIM entries)
    src_strides_ptr,  # pointer to int64 strides of src tensor (NDIM entries)
    dst_strides_ptr,  # pointer to int64 strides of dst tensor (NDIM + 1 entries)
    n_elements,  # total elements to copy (src.numel() == dst.numel())
    NDIM: tl.constexpr,
    INSERT_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy src into dst, where dst is src with a size-1 dim inserted at INSERT_DIM.

    Each program instance handles BLOCK_SIZE consecutive linear element
    indices: the linear index is decomposed into per-dimension coordinates
    (row-major order) and mapped through the source/destination strides.
    """
    pid = tl.program_id(axis=0)
    linear = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = linear < n_elements

    # Do all index arithmetic in 64 bits so large tensors don't overflow.
    linear = linear.to(tl.int64)

    src_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    dst_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    remaining = linear
    # Peel dimensions from innermost to outermost (row-major decomposition).
    # NDIM is constexpr, so this loop is unrolled at compile time.
    for d in range(NDIM - 1, -1, -1):
        dim_size = tl.load(sizes_ptr + d)  # scalar int64
        coord = remaining % dim_size
        remaining = remaining // dim_size

        src_offset += coord * tl.load(src_strides_ptr + d)

        # Source dim d maps to dst dim d before the inserted axis,
        # and to dst dim d + 1 at or after it.
        if d < INSERT_DIM:
            dst_offset += coord * tl.load(dst_strides_ptr + d)
        else:
            dst_offset += coord * tl.load(dst_strides_ptr + (d + 1))

    values = tl.load(src_ptr + src_offset, mask=in_bounds)
    tl.store(dst_ptr + dst_offset, values, mask=in_bounds)

50 

51 

def _launch_unsqueeze_copy(src: torch.Tensor, dim: int, out: torch.Tensor):
    """Launch the Triton kernel that copies ``src`` into ``out``.

    ``out`` must already be shaped like ``src`` with a size-1 dimension
    inserted at ``dim``; this helper only copies the data.

    Args:
        src: CUDA source tensor (arbitrary strides supported).
        dim: already-normalized insertion index, 0 <= dim <= src.dim().
        out: CUDA destination tensor with the same dtype and element count.
    """
    assert src.is_cuda and out.is_cuda, "Tensors must be on CUDA device"
    assert src.dtype == out.dtype, "Dtype mismatch between src and out"
    # The kernel derives every destination offset from src's linear indices,
    # so an element-count mismatch would read/write out of bounds on device —
    # fail fast here instead.
    assert src.numel() == out.numel(), "Element count mismatch between src and out"
    assert 0 <= dim <= src.dim(), "dim must be normalized into [0, src.dim()]"

    n_elements = src.numel()
    if n_elements == 0:
        return  # nothing to copy

    # Shape/stride metadata must live on the device so the kernel can tl.load it.
    sizes = torch.tensor(list(src.shape), dtype=torch.int64, device=src.device)
    src_strides = torch.tensor(list(src.stride()), dtype=torch.int64, device=src.device)
    dst_strides = torch.tensor(list(out.stride()), dtype=torch.int64, device=out.device)

    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    _unsqueeze_copy_kernel[grid](
        src,
        out,
        sizes,
        src_strides,
        dst_strides,
        n_elements,
        NDIM=src.dim(),
        INSERT_DIM=dim,
        BLOCK_SIZE=1024,
    )

77 

78 

def unsqueeze_copy(x: torch.Tensor, dim: int):
    """Return a new tensor equal to ``x`` with a size-1 dimension inserted at ``dim``.

    Materializing equivalent of ``torch.unsqueeze``: the result owns fresh
    storage rather than being a view of ``x``.

    Raises:
        IndexError: if ``dim`` is outside ``[-(x.dim() + 1), x.dim()]``.
    """
    ndim = x.dim()
    # Negative dims count from the end of the *output* shape (ndim + 1 axes).
    norm = dim + ndim + 1 if dim < 0 else dim
    if norm < 0 or norm > ndim:
        raise IndexError(f"dim {dim} out of range for tensor with {ndim} dims")

    result_shape = list(x.shape)
    result_shape.insert(norm, 1)
    out = torch.empty(result_shape, device=x.device, dtype=x.dtype)

    _launch_unsqueeze_copy(x, norm, out)
    return out

91 

92 

def unsqueeze_copy_out(x: torch.Tensor, dim: int, out: torch.Tensor):
    """Copy ``x`` with a size-1 dimension inserted at ``dim`` into ``out``.

    Follows PyTorch ``out=`` semantics: ``out`` is resized in place when its
    shape does not already match the expected result shape.

    Raises:
        IndexError: if ``dim`` is outside ``[-(x.dim() + 1), x.dim()]``.
        ValueError: if ``out`` differs from ``x`` in device or dtype.

    Returns:
        The ``out`` tensor.
    """
    ndim = x.dim()
    # Negative dims count from the end of the *output* shape (ndim + 1 axes).
    norm = dim + ndim + 1 if dim < 0 else dim
    if norm < 0 or norm > ndim:
        raise IndexError(f"dim {dim} out of range for tensor with {ndim} dims")

    if x.device != out.device:
        raise ValueError("out tensor must be on the same device as input")
    if x.dtype != out.dtype:
        raise ValueError("out tensor must have the same dtype as input")

    target_shape = list(x.shape)
    target_shape.insert(norm, 1)
    if list(out.shape) != target_shape:
        out.resize_(target_shape)

    _launch_unsqueeze_copy(x, norm, out)
    return out