Coverage for src/flag_gems/experimental_ops/pixel_unshuffle.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def pixel_unshuffle_kernel(
    in_ptr,  # *Pointer* to input tensor (contiguous NCHW)
    out_ptr,  # *Pointer* to output tensor (contiguous NCHW)
    n_elements,  # total number of elements (N*C*H*W)
    N,
    C,
    H,
    W,  # input dimensions
    R,  # downscale factor
    C_out,
    H_out,
    W_out,  # output dimensions
    BLOCK_SIZE: tl.constexpr,
):
    """Rearrange R x R spatial blocks into channels (inverse of pixel_shuffle).

    Each program instance copies BLOCK_SIZE consecutive *output* elements.
    For an output position (n, c_out, h_out, w_out) the source element is at
    input channel c_out // (R*R) and spatial location
    (h_out*R + dh, w_out*R + dw), where (dh, dw) = divmod(c_out % (R*R), R).
    Both tensors are assumed contiguous NCHW, so linear indices are decoded
    and re-encoded with plain stride arithmetic.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Strides for contiguous NCHW; the innermost (W) stride is 1 and is
    # folded directly into the index expressions below instead of being
    # kept as a dead local (the original carried an unused `sW_out` under
    # a noqa: F841 suppression).
    sN_in = C * H * W
    sC_in = H * W
    sH_in = W

    sN_out = C_out * H_out * W_out
    sC_out = H_out * W_out
    sH_out = W_out

    # Decode output linear index into (n, c_out, h_out, w_out).
    n = offsets // sN_out
    rem1 = offsets - n * sN_out
    c_out = rem1 // sC_out
    rem2 = rem1 - c_out * sC_out
    h_out = rem2 // sH_out
    w_out = rem2 - h_out * sH_out

    # Map output channel to input channel and intra-block spatial offsets.
    r2 = R * R
    c_in = c_out // r2
    remc = c_out - c_in * r2
    dh = remc // R
    dw = remc - dh * R

    # Compute input spatial indices.
    h_in = h_out * R + dh
    w_in = w_out * R + dw

    # Compute input linear index (w_in contributes directly: W-stride is 1).
    in_index = n * sN_in + c_in * sC_in + h_in * sH_in + w_in

    x = tl.load(in_ptr + in_index, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

61 

62 

def _launch_pixel_unshuffle_kernel(
    inp: torch.Tensor,
    downscale_factor: int,
    out: torch.Tensor,
    block_size: int = 1024,
):
    """Validate arguments and launch the Triton pixel-unshuffle kernel.

    Args:
        inp: Contiguous CUDA tensor of shape (N, C, H, W).
        downscale_factor: Spatial downscale factor r; must be > 0 and divide
            both H and W.
        out: Preallocated contiguous CUDA tensor of shape
            (N, C*r*r, H//r, W//r). Written in place.
        block_size: Number of output elements handled per Triton program
            instance (kernel BLOCK_SIZE). Default 1024 matches the original
            hard-coded value, so existing callers are unaffected.
    """
    assert inp.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert inp.is_contiguous(), "Input must be contiguous (NCHW)"
    assert out.is_contiguous(), "Output must be contiguous (NCHW)"
    assert inp.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    N, C, H, W = inp.shape
    r = int(downscale_factor)
    assert r > 0, "downscale_factor must be > 0"
    assert (H % r == 0) and (
        W % r == 0
    ), "H and W must be divisible by downscale_factor"
    C_out = C * r * r
    H_out = H // r
    W_out = W // r
    assert out.shape == (N, C_out, H_out, W_out), "Output has incorrect shape"

    n_elements = inp.numel()
    if n_elements == 0:
        # Nothing to copy; skip the launch (a zero-sized grid is invalid).
        return

    # One-dimensional launch grid covering every output element. The grid is
    # a plain tuple rather than a lambda (E731): BLOCK_SIZE is fixed at
    # launch time, so the deferred META lookup bought nothing.
    grid = (triton.cdiv(n_elements, block_size),)
    pixel_unshuffle_kernel[grid](
        inp,
        out,
        n_elements,
        N,
        C,
        H,
        W,
        r,
        C_out,
        H_out,
        W_out,
        BLOCK_SIZE=block_size,
    )

101 

102 

def pixel_unshuffle(input, downscale_factor, *, layout=None):
    """
    Wrapper for aten::pixel_unshuffle
    Args:
        input: Tensor[N, C, H, W] (contiguous)
        downscale_factor: int
        layout: unused (for API parity)
    """
    # Normalize to a contiguous NCHW view before any index math.
    src = input if input.is_contiguous() else input.contiguous()
    assert src.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = src.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (
        height % factor == 0 and width % factor == 0
    ), "H and W must be divisible by downscale_factor"

    # Allocate the destination and let the kernel fill it in place.
    result = torch.empty(
        (batch, channels * factor * factor, height // factor, width // factor),
        device=src.device,
        dtype=src.dtype,
    )
    _launch_pixel_unshuffle_kernel(src, factor, result)
    return result

126 

127 

def pixel_unshuffle_out(input, downscale_factor, out):
    """
    Wrapper for aten::pixel_unshuffle.out
    Args:
        input: Tensor[N, C, H, W] (contiguous)
        downscale_factor: int
        out: preallocated Tensor[N, C*r*r, H//r, W//r] (contiguous)
    """
    # Normalize to a contiguous NCHW view before validating against `out`.
    src = input if input.is_contiguous() else input.contiguous()
    assert src.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = src.shape
    factor = int(downscale_factor)
    assert factor > 0, "downscale_factor must be > 0"
    assert (
        height % factor == 0 and width % factor == 0
    ), "H and W must be divisible by downscale_factor"

    # `out` must agree with the input on shape, dtype, device and layout.
    expected_shape = (
        batch,
        channels * factor * factor,
        height // factor,
        width // factor,
    )
    assert out.shape == expected_shape, f"out must have shape {expected_shape}"
    assert out.dtype == src.dtype, "out dtype must match input dtype"
    assert out.device == src.device, "out device must match input device"
    if not out.is_contiguous():
        raise ValueError("out must be contiguous")

    _launch_pixel_unshuffle_kernel(src, factor, out)
    return out