Coverage report for src/flag_gems/experimental_ops/unsqueeze.py: 0% of 45 statements covered.
Generated by coverage.py v7.6.9, created at 2026-03-10 02:30 +0800 (navigation: « prev ^ index » next).
1import torch
2import triton
3import triton.language as tl
@triton.jit
def unsqueeze_kernel(
    src_ptr,  # *Pointer* to input tensor data.
    dst_ptr,  # *Pointer* to output tensor data.
    out_numel,  # Total number of elements in output (same as input).
    in_strides_ptr,  # *Pointer* to input strides (length = RANK-1).
    out_shape_ptr,  # *Pointer* to output shape (length = RANK).
    BLOCK_SIZE: tl.constexpr,  # Number of elements processed by each program.
    RANK: tl.constexpr,  # Rank of the output tensor.
    UNSQ_DIM: tl.constexpr,  # The dimension at which to unsqueeze (compile-time constant).
):
    # Copy kernel: each program handles BLOCK_SIZE consecutive linear output
    # indices. A linear index is decomposed (innermost dim first) into
    # per-dimension coordinates using the output shape; the size-1 UNSQ_DIM
    # coordinate is skipped and the remaining coordinates are combined with
    # the input strides to locate the source element. The output is written
    # at the linear index directly (assumes dst is contiguous — it is
    # allocated with torch.empty by the host-side wrapper).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Promote to int64 so index decomposition cannot overflow on large tensors.
    offsets = offsets.to(tl.int64)
    mask = offsets < out_numel

    tmp = offsets
    src_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    # Decompose linear index into multi-dimensional coordinates and map to input offset
    for k in range(RANK - 1, -1, -1):  # iterate from fastest-varying dim outward
        s_k = tl.load(out_shape_ptr + k)
        c_k = tmp % s_k  # coordinate along output dim k
        tmp = tmp // s_k
        if k != UNSQ_DIM:
            # Output dim k maps to input dim k (before the insertion point)
            # or k-1 (after it). The UNSQ_DIM coordinate is always 0 and
            # contributes nothing to the source offset, so it is skipped.
            in_k = k if k < UNSQ_DIM else k - 1
            stride_in_k = tl.load(in_strides_ptr + in_k)
            src_offset += c_k * stride_in_k

    vals = tl.load(src_ptr + src_offset, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)
def unsqueeze(*args, **kwargs):
    """Return a new tensor with a dimension of size one inserted at ``dim``.

    Unlike ``torch.unsqueeze`` (which returns a view), this materializes the
    result by copying the data with a Triton kernel.

    Accepted call forms: ``unsqueeze(x, dim)``, ``unsqueeze(x, dim=d)``, and
    ``unsqueeze(self=x, dim=d)`` / ``unsqueeze(input=x, dim=d)``.

    Args:
        x (torch.Tensor): CUDA tensor to unsqueeze.
        dim (int): Insertion position, in ``[-(x.dim()+1), x.dim()]``.

    Returns:
        torch.Tensor: New contiguous tensor with shape ``x.shape`` plus a
        size-1 dimension at ``dim``, same dtype/device as ``x``.

    Raises:
        AssertionError: If ``x`` is not a CUDA tensor, ``dim`` is not an int,
            or ``dim`` is out of range.
    """
    # Resolve (x, dim): take x positionally when given, and dim from either
    # position or keyword. The previous logic required BOTH to be positional
    # or BOTH keyword, so the common mixed call unsqueeze(x, dim=0) was
    # routed to the kwargs branch and wrongly rejected.
    if args:
        x = args[0]
        dim = args[1] if len(args) >= 2 else kwargs.get("dim", None)
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        dim = kwargs.get("dim", None)

    # NOTE: assert-based validation is kept so existing callers that catch
    # AssertionError keep working; be aware asserts vanish under `python -O`.
    assert isinstance(
        x, torch.Tensor
    ), "unsqueeze expects a torch.Tensor as the first argument."
    assert isinstance(dim, int), "unsqueeze expects an integer 'dim' argument."
    assert x.is_cuda, "Input tensor must be on CUDA device."

    ndims = x.dim()
    # Negative dims index from the end of the *output* shape, hence ndims + 1.
    if dim < 0:
        dim += ndims + 1
    assert 0 <= dim <= ndims, f"dim must be in range [0, {ndims}] after normalization."

    out_shape = list(x.shape)
    out_shape.insert(dim, 1)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)

    numel = out.numel()
    if numel == 0:
        # Nothing to copy; skip building index tensors and launching the kernel.
        return out

    # Prepare strides and shape tensors on device for the kernel to index.
    in_strides = torch.tensor(x.stride(), dtype=torch.int64, device=x.device)
    out_shape_t = torch.tensor(out_shape, dtype=torch.int64, device=x.device)

    RANK = len(out_shape)
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    unsqueeze_kernel[grid](
        x,
        out,
        numel,
        in_strides,
        out_shape_t,
        BLOCK_SIZE=BLOCK_SIZE,
        RANK=RANK,
        UNSQ_DIM=dim,
    )
    return out