Coverage for src/flag_gems/experimental_ops/unsqueeze_copy.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import torch
2import triton
3import triton.language as tl
# Copies every element of `src` into `dst`, where `dst` holds the same data as
# `src` but with one extra dimension of size 1 inserted at position INSERT_DIM
# (the result of torch.unsqueeze followed by a copy). Offsets are computed from
# the explicit shape/stride metadata, so arbitrary (non-contiguous) layouts of
# both tensors are supported.
@triton.jit
def _unsqueeze_copy_kernel(
    src_ptr,  # pointer to input tensor data
    dst_ptr,  # pointer to output tensor data
    sizes_ptr,  # pointer to int64 sizes of src tensor (NDIM)
    src_strides_ptr,  # pointer to int64 strides of src tensor (NDIM)
    dst_strides_ptr,  # pointer to int64 strides of dst tensor (NDIM + 1)
    n_elements,  # total number of elements to copy (src.numel() == dst.numel())
    NDIM: tl.constexpr,  # rank of src; the loop below unrolls at compile time
    INSERT_DIM: tl.constexpr,  # position of the inserted size-1 dim in dst
    BLOCK_SIZE: tl.constexpr,  # linear elements handled per program instance
):
    # Each program instance covers one contiguous block of linear indices.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    # use int64 for index math (avoids 32-bit overflow on large tensors)
    offs = offs.to(tl.int64)

    # Compute source and destination element offsets using shape/strides
    src_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    dst_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    rem = offs
    # Decompose linear index into multi-dimensional indices (row-major order),
    # walking dimensions from innermost (fastest-varying) to outermost.
    for rev_d in range(NDIM - 1, -1, -1):
        sz_d = tl.load(sizes_ptr + rev_d)  # scalar int64
        idx_d = rem % sz_d
        rem = rem // sz_d

        sstride_d = tl.load(src_strides_ptr + rev_d)
        src_off += idx_d * sstride_d

        # Map source dim rev_d to destination dim (account for inserted dim):
        # dims before INSERT_DIM keep their position, dims at/after it shift by
        # one. INSERT_DIM is constexpr, so this branch resolves at compile time.
        if rev_d < INSERT_DIM:
            dstride_d = tl.load(dst_strides_ptr + rev_d)
            dst_off += idx_d * dstride_d
        else:
            dstride_shift = tl.load(dst_strides_ptr + (rev_d + 1))
            dst_off += idx_d * dstride_shift

    vals = tl.load(src_ptr + src_off, mask=mask)
    tl.store(dst_ptr + dst_off, vals, mask=mask)
52def _launch_unsqueeze_copy(src: torch.Tensor, dim: int, out: torch.Tensor):
53 assert src.is_cuda and out.is_cuda, "Tensors must be on CUDA device"
54 assert src.dtype == out.dtype, "Dtype mismatch between src and out"
56 n_elements = src.numel()
57 if n_elements == 0:
58 return # nothing to copy
60 # Build metadata arrays on device
61 sizes = torch.tensor(list(src.shape), dtype=torch.int64, device=src.device)
62 src_strides = torch.tensor(list(src.stride()), dtype=torch.int64, device=src.device)
63 dst_strides = torch.tensor(list(out.stride()), dtype=torch.int64, device=out.device)
65 grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
66 _unsqueeze_copy_kernel[grid](
67 src,
68 out,
69 sizes,
70 src_strides,
71 dst_strides,
72 n_elements,
73 NDIM=src.dim(),
74 INSERT_DIM=dim,
75 BLOCK_SIZE=1024,
76 )
def unsqueeze_copy(x: torch.Tensor, dim: int):
    """Return a copy of ``x`` with a dimension of size one inserted at ``dim``.

    Behaves like ``torch.unsqueeze(x, dim)`` followed by a materializing copy.
    A negative ``dim`` counts from the end of the *output* shape, matching
    PyTorch semantics.

    Raises:
        IndexError: if ``dim`` falls outside ``[-(x.dim() + 1), x.dim()]``.
    """
    ndim = x.dim()
    # Wrap negative dims; valid normalized positions are 0..ndim inclusive.
    d = dim + ndim + 1 if dim < 0 else dim
    if d < 0 or d > ndim:
        raise IndexError(f"dim {dim} out of range for tensor with {ndim} dims")

    result_shape = list(x.shape)
    result_shape.insert(d, 1)
    result = torch.empty(result_shape, device=x.device, dtype=x.dtype)

    _launch_unsqueeze_copy(x, d, result)
    return result
def unsqueeze_copy_out(x: torch.Tensor, dim: int, out: torch.Tensor):
    """Copy ``x`` into ``out`` with a size-1 dimension inserted at ``dim``.

    Follows PyTorch ``out=`` semantics: ``out`` is resized in place when its
    shape does not already match the expected result shape.

    Raises:
        IndexError: if ``dim`` is out of range for the output rank.
        ValueError: if ``out`` differs from ``x`` in device or dtype.
    """
    ndim = x.dim()
    # Wrap negative dims; valid normalized positions are 0..ndim inclusive.
    d = dim if dim >= 0 else dim + ndim + 1
    if not 0 <= d <= ndim:
        raise IndexError(f"dim {dim} out of range for tensor with {ndim} dims")

    if out.device != x.device:
        raise ValueError("out tensor must be on the same device as input")
    if out.dtype != x.dtype:
        raise ValueError("out tensor must have the same dtype as input")

    # Ensure out has the correct shape (resize_ follows PyTorch out semantics)
    target_shape = list(x.shape)
    target_shape.insert(d, 1)
    if list(out.shape) != target_shape:
        out.resize_(target_shape)

    _launch_unsqueeze_copy(x, d, out)
    return out