# Coverage for src/flag_gems/experimental_ops/pixel_unshuffle.py: 0%
# 82 statements (coverage.py v7.6.9, created at 2026-03-12 02:21 +0800)
1import torch
2import triton
3import triton.language as tl
@triton.jit
def pixel_unshuffle_kernel(
    in_ptr,  # *Pointer* to input tensor (contiguous NCHW)
    out_ptr,  # *Pointer* to output tensor (contiguous NCHW)
    n_elements,  # total number of elements (N*C*H*W)
    N,
    C,
    H,
    W,  # input dimensions
    R,  # downscale factor
    C_out,
    H_out,
    W_out,  # output dimensions
    BLOCK_SIZE: tl.constexpr,
):
    """Gather kernel for pixel_unshuffle: (N, C, H, W) -> (N, C*R*R, H/R, W/R).

    Each program instance handles BLOCK_SIZE consecutive elements of the
    flat OUTPUT tensor: it decodes each output linear index into
    (n, c_out, h_out, w_out), maps that back to the input coordinate, and
    copies the value. One load + one store per element; stores are fully
    coalesced since `offsets` is the contiguous output index.
    NOTE(review): assumes the caller guarantees C_out == C*R*R,
    H_out == H//R, W_out == W//R — not re-checked here.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard against the final, partially-filled block.
    mask = offsets < n_elements

    # Strides for contiguous NCHW
    sN_in = C * H * W
    sC_in = H * W
    sH_in = W
    sW_in = 1

    sN_out = C_out * H_out * W_out
    sC_out = H_out * W_out
    sH_out = W_out
    sW_out = 1  # noqa: F841

    # Decode output linear index into (n, c_out, h_out, w_out).
    # `rem - q * divisor` is used instead of `%` to reuse the quotient.
    n = offsets // sN_out
    rem1 = offsets - n * sN_out
    c_out = rem1 // sC_out
    rem2 = rem1 - c_out * sC_out
    h_out = rem2 // sH_out
    w_out = rem2 - h_out * sH_out

    # Map output channel to input channel and spatial offsets:
    # c_out = c_in * R*R + dh * R + dw (row-major over the R x R tile).
    r2 = R * R
    c_in = c_out // r2
    remc = c_out - c_in * r2
    dh = remc // R
    dw = remc - dh * R

    # Compute input spatial indices
    h_in = h_out * R + dh
    w_in = w_out * R + dw

    # Compute input linear index
    in_index = n * sN_in + c_in * sC_in + h_in * sH_in + w_in * sW_in

    x = tl.load(in_ptr + in_index, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
def _launch_pixel_unshuffle_kernel(
    inp: torch.Tensor, downscale_factor: int, out: torch.Tensor
):
    """Validate arguments and launch the Triton pixel-unshuffle kernel.

    `inp` must be a contiguous 4D CUDA tensor (N, C, H, W); `out` must be a
    contiguous CUDA tensor of shape (N, C*r*r, H//r, W//r). Returns None;
    the result is written into `out`.
    """
    assert inp.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert inp.is_contiguous(), "Input must be contiguous (NCHW)"
    assert out.is_contiguous(), "Output must be contiguous (NCHW)"
    assert inp.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"

    batch, channels, height, width = inp.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (height % factor == 0) and (
        width % factor == 0
    ), "H and W must be divisible by downscale_factor"

    out_channels = channels * factor * factor
    out_height = height // factor
    out_width = width // factor
    assert out.shape == (
        batch,
        out_channels,
        out_height,
        out_width,
    ), "Output has incorrect shape"

    total = inp.numel()
    # Nothing to copy for an empty tensor; skip the launch entirely.
    if total == 0:
        return

    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    pixel_unshuffle_kernel[grid](
        inp,
        out,
        total,
        batch,
        channels,
        height,
        width,
        factor,
        out_channels,
        out_height,
        out_width,
        BLOCK_SIZE=BLOCK_SIZE,
    )
def pixel_unshuffle(input, downscale_factor, *, layout=None):
    """
    Wrapper for aten::pixel_unshuffle
    Args:
        input: Tensor[N, C, H, W] (contiguous)
        downscale_factor: int
        layout: unused (for API parity)
    """
    # Kernel requires a contiguous NCHW layout; copy only if needed.
    x = input if input.is_contiguous() else input.contiguous()
    assert x.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = x.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (height % factor == 0) and (
        width % factor == 0
    ), "H and W must be divisible by downscale_factor"

    out = torch.empty(
        (batch, channels * factor * factor, height // factor, width // factor),
        device=x.device,
        dtype=x.dtype,
    )
    _launch_pixel_unshuffle_kernel(x, factor, out)
    return out
def pixel_unshuffle_out(input, downscale_factor, out):
    """
    Wrapper for aten::pixel_unshuffle.out
    Args:
        input: Tensor[N, C, H, W] (contiguous)
        downscale_factor: int
        out: preallocated Tensor[N, C*r*r, H//r, W//r] (contiguous)
    """
    # Kernel requires a contiguous NCHW layout; copy only if needed.
    x = input if input.is_contiguous() else input.contiguous()
    assert x.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = x.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (height % factor == 0) and (
        width % factor == 0
    ), "H and W must be divisible by downscale_factor"

    expected_shape = (
        batch,
        channels * factor * factor,
        height // factor,
        width // factor,
    )
    assert out.shape == expected_shape, f"out must have shape {expected_shape}"
    assert out.dtype == x.dtype, "out dtype must match input dtype"
    assert out.device == x.device, "out device must match input device"
    # Unlike the input, a non-contiguous `out` cannot be silently fixed —
    # the caller holds the buffer we must write into.
    if not out.is_contiguous():
        raise ValueError("out must be contiguous")

    _launch_pixel_unshuffle_kernel(x, factor, out)
    return out