Coverage for src/flag_gems/ops/pixel_unshuffle.py: 63%
87 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn

logger = logging.getLogger(__name__)
@triton.jit
def pixel_unshuffle_kernel(
    in_ptr,  # pointer to the input tensor (contiguous NCHW)
    out_ptr,  # pointer to the output tensor (contiguous NCHW)
    n_elements,  # total number of elements (N*C*H*W)
    N,
    C,
    H,
    W,  # input dimensions
    R,  # downscale factor
    C_out,
    H_out,
    W_out,  # output dimensions
    BLOCK_SIZE: tl.constexpr,
):
    # One flat output element per lane; each program covers BLOCK_SIZE lanes.
    pid = tl.program_id(axis=0)
    out_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = out_idx < n_elements

    # Row-major strides for the contiguous NCHW input.
    in_stride_n = C * H * W
    in_stride_c = H * W
    in_stride_h = W

    # Row-major strides for the contiguous NCHW output.
    out_stride_n = C_out * H_out * W_out
    out_stride_c = H_out * W_out
    out_stride_h = W_out

    # Unravel the flat output index into (n, c_out, h_out, w_out).
    n = out_idx // out_stride_n
    c_out = (out_idx % out_stride_n) // out_stride_c
    h_out = (out_idx % out_stride_c) // out_stride_h
    w_out = out_idx % out_stride_h

    # pixel_unshuffle packs each RxR spatial tile into R*R output channels:
    # c_out = c_in * (R*R) + dh * R + dw.
    c_in = c_out // (R * R)
    tile = c_out % (R * R)
    dh = tile // R
    dw = tile % R

    # Corresponding source position in the input's spatial plane.
    h_in = h_out * R + dh
    w_in = w_out * R + dw

    # Gather from the input and write to the (already flat) output index.
    src = n * in_stride_n + c_in * in_stride_c + h_in * in_stride_h + w_in
    val = tl.load(in_ptr + src, mask=valid)
    tl.store(out_ptr + out_idx, val, mask=valid)
def _launch_pixel_unshuffle_kernel(
    inp: torch.Tensor, downscale_factor: int, out: torch.Tensor
):
    """Validate the tensors and launch the Triton pixel-unshuffle kernel.

    Both ``inp`` and ``out`` must be contiguous NCHW tensors, and ``out``
    must already have the pixel-unshuffled shape (N, C*r*r, H//r, W//r).
    """
    assert inp.is_contiguous(), "Input must be contiguous (NCHW)"
    assert out.is_contiguous(), "Output must be contiguous (NCHW)"
    assert inp.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = inp.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (height % factor == 0) and (
        width % factor == 0
    ), "H and W must be divisible by downscale_factor"
    out_channels = channels * factor * factor
    out_height = height // factor
    out_width = width // factor
    expected = (batch, out_channels, out_height, out_width)
    assert out.shape == expected, "Output has incorrect shape"

    total = inp.numel()
    if total == 0:
        # Empty tensor: nothing to copy, skip the kernel launch.
        return

    BLOCK_SIZE = 1024
    # BLOCK_SIZE is fixed, so the grid can be computed up front.
    grid = (triton.cdiv(total, BLOCK_SIZE),)
    with torch_device_fn.device(inp.device):
        pixel_unshuffle_kernel[grid](
            inp,
            out,
            total,
            batch,
            channels,
            height,
            width,
            factor,
            out_channels,
            out_height,
            out_width,
            BLOCK_SIZE=BLOCK_SIZE,
        )
def pixel_unshuffle(input, downscale_factor, *, layout=None):
    """Rearrange (N, C, H, W) into (N, C*r*r, H//r, W//r) with r = downscale_factor.

    ``layout`` is accepted for signature compatibility and is not used here.
    """
    logger.debug("GEMS PIXEL_UNSHUFFLE")
    x = input if input.is_contiguous() else input.contiguous()
    assert x.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    n, c, h, w = x.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (h % factor == 0) and (
        w % factor == 0
    ), "H and W must be divisible by downscale_factor"
    result = torch.empty(
        (n, c * factor * factor, h // factor, w // factor),
        device=x.device,
        dtype=x.dtype,
    )
    _launch_pixel_unshuffle_kernel(x, factor, result)
    return result
def pixel_unshuffle_out(input, downscale_factor, out):
    """Out-variant of pixel_unshuffle: write the result into ``out`` and return it.

    ``out`` must be a contiguous tensor of shape (N, C*r*r, H//r, W//r) with
    the same dtype and device as ``input``.
    """
    logger.debug("GEMS PIXEL_UNSHUFFLE_OUT")
    src = input.contiguous() if not input.is_contiguous() else input
    assert src.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    n, c, h, w = src.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (h % factor == 0) and (
        w % factor == 0
    ), "H and W must be divisible by downscale_factor"
    expected_shape = (n, c * factor * factor, h // factor, w // factor)
    assert out.shape == expected_shape, f"out must have shape {expected_shape}"
    assert out.dtype == src.dtype, "out dtype must match input dtype"
    assert out.device == src.device, "out device must match input device"
    if not out.is_contiguous():
        raise ValueError("out must be contiguous")

    _launch_pixel_unshuffle_kernel(src, factor, out)
    return out