Coverage for src/flag_gems/experimental_ops/pixel_shuffle.py: 0%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def pixel_shuffle_kernel(
    in_ptr,
    out_ptr,
    N,
    C_out,
    H,
    W,
    R,
    H_out,
    W_out,
    s_in_n,
    s_in_c,
    s_in_h,
    s_in_w,
    s_out_n,
    s_out_c,
    s_out_h,
    s_out_w,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one block of output elements of a pixel-shuffle rearrangement.

    Each program handles ``BLOCK_SIZE`` consecutive flat output offsets. A flat
    offset is decomposed into logical output coordinates ``(no, co, ho, wo)``
    over the shape ``(N, C_out, H_out, W_out)``, and the element is read from
    the input at::

        channel = co * R * R + (ho % R) * R + (wo % R)
        row     = ho // R
        column  = wo // R

    Both tensors are addressed through their explicit strides, so neither has
    to be contiguous.

    ``N``, ``H`` and ``W`` are not needed by the index arithmetic; they are
    kept in the signature so existing launch sites keep working.
    """
    pid = tl.program_id(axis=0)

    # Promote to int64 *before* multiplying: doing ``pid * BLOCK_SIZE`` in
    # 32-bit and casting afterwards would already have wrapped for outputs
    # with 2**31 or more elements.
    block_start = tl.cast(pid, tl.int64) * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    # 64-bit copies of every scalar that takes part in address arithmetic.
    C64 = tl.cast(C_out, tl.int64)
    R64 = tl.cast(R, tl.int64)
    H_out64 = tl.cast(H_out, tl.int64)
    W_out64 = tl.cast(W_out, tl.int64)

    s_in_n64 = tl.cast(s_in_n, tl.int64)
    s_in_c64 = tl.cast(s_in_c, tl.int64)
    s_in_h64 = tl.cast(s_in_h, tl.int64)
    s_in_w64 = tl.cast(s_in_w, tl.int64)

    s_out_n64 = tl.cast(s_out_n, tl.int64)
    s_out_c64 = tl.cast(s_out_c, tl.int64)
    s_out_h64 = tl.cast(s_out_h, tl.int64)
    s_out_w64 = tl.cast(s_out_w, tl.int64)

    # Unravel the flat offset into (no, co, ho, wo), innermost dimension first.
    wo = offs % W_out64
    tmp = offs // W_out64
    ho = tmp % H_out64
    tmp = tmp // H_out64
    co = tmp % C64
    no = tmp // C64

    # Split each output spatial coordinate into the input coordinate and the
    # sub-pixel position inside the R x R cell.
    rh = ho % R64
    h = ho // R64
    rw = wo % R64
    w = wo // R64

    # Input channel that holds sub-pixel (rh, rw) of output channel co.
    cin = co * (R64 * R64) + rh * R64 + rw

    in_idx = no * s_in_n64 + cin * s_in_c64 + h * s_in_h64 + w * s_in_w64
    out_idx = no * s_out_n64 + co * s_out_c64 + ho * s_out_h64 + wo * s_out_w64

    # Lanes past the end of the output are masked on both the load and store.
    val = tl.load(in_ptr + in_idx, mask=mask, other=0)
    tl.store(out_ptr + out_idx, val, mask=mask)
73def _check_and_get_shapes_strides(x: torch.Tensor, upscale_factor: int):
74 if x.dim() != 4:
75 raise RuntimeError(
76 f"pixel_shuffle expects a 4D tensor (N, C, H, W), but got {x.dim()}D"
77 )
78 if upscale_factor <= 0:
79 raise RuntimeError("upscale_factor must be > 0")
80 N, C_in, H, W = x.shape
81 r2 = upscale_factor * upscale_factor
82 if C_in % r2 != 0:
83 raise RuntimeError(
84 f"Input channel dimension {C_in} is not divisible by upscale_factor^2={r2}"
85 )
86 C_out = C_in // r2
87 H_out = H * upscale_factor
88 W_out = W * upscale_factor
89 in_strides = x.stride()
90 return (N, C_in, H, W, C_out, H_out, W_out, in_strides)
93def _launch_pixel_shuffle_kernel(
94 x: torch.Tensor, out: torch.Tensor, upscale_factor: int
95):
96 N, C_in, H, W = x.shape
97 C_out = C_in // (upscale_factor * upscale_factor)
98 H_out = H * upscale_factor
99 W_out = W * upscale_factor
101 s_in_n, s_in_c, s_in_h, s_in_w = x.stride()
102 s_out_n, s_out_c, s_out_h, s_out_w = out.stride()
104 n_elements = out.numel()
105 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
106 pixel_shuffle_kernel[grid](
107 x,
108 out,
109 N,
110 C_out,
111 H,
112 W,
113 upscale_factor,
114 H_out,
115 W_out,
116 s_in_n,
117 s_in_c,
118 s_in_h,
119 s_in_w,
120 s_out_n,
121 s_out_c,
122 s_out_h,
123 s_out_w,
124 n_elements,
125 BLOCK_SIZE=1024,
126 )
def pixel_shuffle(self: torch.Tensor, upscale_factor: int):
    """Rearrange a (N, C*r*r, H, W) CUDA tensor into (N, C, H*r, W*r).

    Args:
        self: 4D CUDA input tensor.
        upscale_factor: Integer factor ``r`` by which height and width grow.

    Returns:
        A newly allocated tensor with the same dtype, device and layout as the
        input, holding the rearranged data.

    Raises:
        RuntimeError: If the input is not a CUDA tensor, or if shape
            validation in ``_check_and_get_shapes_strides`` fails.
        TypeError: If ``upscale_factor`` is not an ``int``.
    """
    if not self.is_cuda:
        raise RuntimeError("pixel_shuffle: input must be a CUDA tensor")
    if not isinstance(upscale_factor, int):
        raise TypeError("pixel_shuffle: upscale_factor must be an integer")

    geometry = _check_and_get_shapes_strides(self, upscale_factor)
    # geometry = (N, C_in, H, W, C_out, H_out, W_out, in_strides)
    result_shape = (geometry[0], geometry[4], geometry[5], geometry[6])

    result = torch.empty(
        result_shape,
        dtype=self.dtype,
        device=self.device,
        layout=self.layout,
    )
    _launch_pixel_shuffle_kernel(self, result, upscale_factor)
    return result
def pixel_shuffle_out(self: torch.Tensor, upscale_factor: int, out: torch.Tensor):
    """Pixel-shuffle ``self`` into a caller-provided ``out`` tensor.

    Args:
        self: 4D CUDA input tensor of shape (N, C*r*r, H, W).
        upscale_factor: Integer factor ``r`` by which height and width grow.
        out: CUDA destination tensor of shape (N, C, H*r, W*r) with the same
            dtype as ``self`` and located on the same device. Written in
            place.

    Returns:
        ``out``, after it has been filled.

    Raises:
        RuntimeError: If either tensor is not a CUDA tensor, if the two
            tensors live on different devices, if shape validation in
            ``_check_and_get_shapes_strides`` fails, or if ``out`` has the
            wrong shape or dtype.
        TypeError: If ``upscale_factor`` is not an ``int``.
    """
    if not self.is_cuda or not out.is_cuda:
        raise RuntimeError("pixel_shuffle_out: input and out must be CUDA tensors")
    # Both being CUDA tensors is not enough: the kernel is launched once, for
    # one device, so a destination on a different GPU must be rejected up
    # front rather than reach the kernel.
    if out.device != self.device:
        raise RuntimeError(
            "pixel_shuffle_out: input and out must be on the same device, "
            f"but got {self.device} and {out.device}"
        )
    if not isinstance(upscale_factor, int):
        raise TypeError("pixel_shuffle_out: upscale_factor must be an integer")

    N, _, _, _, C_out, H_out, W_out, _ = _check_and_get_shapes_strides(
        self, upscale_factor
    )
    expected_shape = (N, C_out, H_out, W_out)
    actual_shape = tuple(out.shape)
    if actual_shape != expected_shape:
        raise RuntimeError(
            "pixel_shuffle_out: out tensor has incorrect shape, "
            f"expected {expected_shape} but got {actual_shape}"
        )
    if out.dtype != self.dtype:
        raise RuntimeError("pixel_shuffle_out: out dtype must match input dtype")

    _launch_pixel_shuffle_kernel(self, out, upscale_factor)
    return out