Coverage for src/flag_gems/experimental_ops/pixel_shuffle.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def pixel_shuffle_kernel(
    in_ptr,
    out_ptr,
    N,
    C_out,
    H,
    W,
    R,
    H_out,
    W_out,
    s_in_n,
    s_in_c,
    s_in_h,
    s_in_w,
    s_out_n,
    s_out_c,
    s_out_h,
    s_out_w,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy input elements to their pixel-shuffled output positions.

    Each program handles ``BLOCK_SIZE`` consecutive linear output indices.
    A linear index is decomposed as ``(n, c, ho, wo)`` over the logical
    output shape ``(N, C_out, H_out, W_out)``; the value is read from input
    coordinate ``(n, c * R * R + (ho % R) * R + (wo % R), ho // R, wo // R)``.

    Args:
        in_ptr, out_ptr: Base pointers of the input and output tensors.
        N, H, W: Input batch/height/width. Unused by the index math; kept so
            the launch signature stays stable.
        C_out, H_out, W_out: Logical output channel/height/width extents.
        R: Upscale factor.
        s_in_*: Input strides, added directly to ``in_ptr`` (element units).
        s_out_*: Output strides, added directly to ``out_ptr`` (element units).
        n_elements: Number of output elements to process.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Widen the program id BEFORE multiplying: an int32 `pid * BLOCK_SIZE`
    # wraps around once the output holds 2**31 or more elements.
    pid = tl.cast(tl.program_id(axis=0), tl.int64)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    # All index arithmetic below is carried out in 64-bit integers.
    C64 = tl.cast(C_out, tl.int64)
    R64 = tl.cast(R, tl.int64)
    H_out64 = tl.cast(H_out, tl.int64)
    W_out64 = tl.cast(W_out, tl.int64)

    s_in_n64 = tl.cast(s_in_n, tl.int64)
    s_in_c64 = tl.cast(s_in_c, tl.int64)
    s_in_h64 = tl.cast(s_in_h, tl.int64)
    s_in_w64 = tl.cast(s_in_w, tl.int64)

    s_out_n64 = tl.cast(s_out_n, tl.int64)
    s_out_c64 = tl.cast(s_out_c, tl.int64)
    s_out_h64 = tl.cast(s_out_h, tl.int64)
    s_out_w64 = tl.cast(s_out_w, tl.int64)

    # Decompose the linear output index into (no, co, ho, wo), with the
    # width coordinate varying fastest.
    wo = offs % W_out64
    tmp = offs // W_out64
    ho = tmp % H_out64
    tmp = tmp // H_out64
    co = tmp % C64
    no = tmp // C64

    # Split each output spatial coordinate into the input coordinate and the
    # sub-pixel position inside the R x R block.
    rh = ho % R64
    h = ho // R64
    rw = wo % R64
    w = wo // R64

    # The sub-pixel position selects one of the R*R input channels that feed
    # output channel `co`.
    cin = co * (R64 * R64) + rh * R64 + rw

    in_idx = no * s_in_n64 + cin * s_in_c64 + h * s_in_h64 + w * s_in_w64
    out_idx = no * s_out_n64 + co * s_out_c64 + ho * s_out_h64 + wo * s_out_w64

    val = tl.load(in_ptr + in_idx, mask=mask, other=0)
    tl.store(out_ptr + out_idx, val, mask=mask)

71 

72 

73def _check_and_get_shapes_strides(x: torch.Tensor, upscale_factor: int): 

74 if x.dim() != 4: 

75 raise RuntimeError( 

76 f"pixel_shuffle expects a 4D tensor (N, C, H, W), but got {x.dim()}D" 

77 ) 

78 if upscale_factor <= 0: 

79 raise RuntimeError("upscale_factor must be > 0") 

80 N, C_in, H, W = x.shape 

81 r2 = upscale_factor * upscale_factor 

82 if C_in % r2 != 0: 

83 raise RuntimeError( 

84 f"Input channel dimension {C_in} is not divisible by upscale_factor^2={r2}" 

85 ) 

86 C_out = C_in // r2 

87 H_out = H * upscale_factor 

88 W_out = W * upscale_factor 

89 in_strides = x.stride() 

90 return (N, C_in, H, W, C_out, H_out, W_out, in_strides) 

91 

92 

93def _launch_pixel_shuffle_kernel( 

94 x: torch.Tensor, out: torch.Tensor, upscale_factor: int 

95): 

96 N, C_in, H, W = x.shape 

97 C_out = C_in // (upscale_factor * upscale_factor) 

98 H_out = H * upscale_factor 

99 W_out = W * upscale_factor 

100 

101 s_in_n, s_in_c, s_in_h, s_in_w = x.stride() 

102 s_out_n, s_out_c, s_out_h, s_out_w = out.stride() 

103 

104 n_elements = out.numel() 

105 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731 

106 pixel_shuffle_kernel[grid]( 

107 x, 

108 out, 

109 N, 

110 C_out, 

111 H, 

112 W, 

113 upscale_factor, 

114 H_out, 

115 W_out, 

116 s_in_n, 

117 s_in_c, 

118 s_in_h, 

119 s_in_w, 

120 s_out_n, 

121 s_out_c, 

122 s_out_h, 

123 s_out_w, 

124 n_elements, 

125 BLOCK_SIZE=1024, 

126 ) 

127 

128 

def pixel_shuffle(self: torch.Tensor, upscale_factor: int):
    """Return a newly allocated pixel-shuffled copy of a 4D CUDA tensor.

    Rearranges an input of shape (N, C * r^2, H, W) into an output of shape
    (N, C, H * r, W * r), where ``r`` is ``upscale_factor``.

    Args:
        self: Input tensor; must be a CUDA tensor.
        upscale_factor: Spatial upscaling factor; must be an ``int``.

    Returns:
        A new tensor with the input's dtype, device and layout.

    Raises:
        RuntimeError: If the input is not a CUDA tensor, or shape/factor
            validation fails.
        TypeError: If ``upscale_factor`` is not an ``int``.
    """
    if not self.is_cuda:
        raise RuntimeError("pixel_shuffle: input must be a CUDA tensor")
    if not isinstance(upscale_factor, int):
        raise TypeError("pixel_shuffle: upscale_factor must be an integer")

    # Only the batch size and the derived output extents are needed here.
    batch, _, _, _, out_channels, out_height, out_width, _ = (
        _check_and_get_shapes_strides(self, upscale_factor)
    )

    like_input = dict(dtype=self.dtype, device=self.device, layout=self.layout)
    result = torch.empty((batch, out_channels, out_height, out_width), **like_input)

    _launch_pixel_shuffle_kernel(self, result, upscale_factor)
    return result

145 

146 

def pixel_shuffle_out(self: torch.Tensor, upscale_factor: int, out: torch.Tensor):
    """Pixel-shuffle a 4D CUDA tensor into a caller-provided output tensor.

    Rearranges an input of shape (N, C * r^2, H, W) into ``out`` of shape
    (N, C, H * r, W * r), where ``r`` is ``upscale_factor``. ``out`` is never
    resized; it must already have the expected shape.

    Args:
        self: Input tensor; must be a CUDA tensor.
        upscale_factor: Spatial upscaling factor; must be an ``int``.
        out: Destination CUDA tensor with the expected output shape, the same
            dtype as the input, and on the same device as the input.

    Returns:
        ``out``, after it has been filled.

    Raises:
        RuntimeError: If either tensor is not a CUDA tensor, shape/factor
            validation fails, ``out`` has the wrong shape or dtype, or the two
            tensors are on different devices.
        TypeError: If ``upscale_factor`` is not an ``int``.
    """
    if not self.is_cuda or not out.is_cuda:
        raise RuntimeError("pixel_shuffle_out: input and out must be CUDA tensors")
    if not isinstance(upscale_factor, int):
        raise TypeError("pixel_shuffle_out: upscale_factor must be an integer")
    N, _, _, _, C_out, H_out, W_out, _ = _check_and_get_shapes_strides(
        self, upscale_factor
    )
    expected_shape = (N, C_out, H_out, W_out)
    if tuple(out.shape) != expected_shape:
        raise RuntimeError(
            f"pixel_shuffle_out: out tensor has incorrect shape, expected {expected_shape} but got {tuple(out.shape)}"
        )
    if out.dtype != self.dtype:
        raise RuntimeError("pixel_shuffle_out: out dtype must match input dtype")
    # Both tensors are handed to a single kernel launch, so they must be on
    # the same device. Checked last so the precedence of the existing errors
    # above is unchanged.
    if out.device != self.device:
        raise RuntimeError(
            "pixel_shuffle_out: input and out must be on the same device, "
            f"but got {self.device} and {out.device}"
        )
    _launch_pixel_shuffle_kernel(self, out, upscale_factor)
    return out