Coverage for src/flag_gems/ops/pixel_unshuffle.py: 63%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def pixel_unshuffle_kernel(
    in_ptr,  # *Pointer* to input tensor (contiguous NCHW)
    out_ptr,  # *Pointer* to output tensor (contiguous NCHW)
    n_elements,  # total number of elements (N*C*H*W)
    N,
    C,
    H,
    W,  # input dimensions
    R,  # downscale factor
    C_out,
    H_out,
    W_out,  # output dimensions
    BLOCK_SIZE: tl.constexpr,
):
    # Gather kernel for pixel_unshuffle: each program handles BLOCK_SIZE
    # consecutive OUTPUT elements. For every output linear index we decode
    # (n, c_out, h_out, w_out), map it back to the input coordinate
    #   c_in = c_out // R^2, (dh, dw) = divmod(c_out % R^2, R),
    #   h_in = h_out * R + dh, w_in = w_out * R + dw,
    # and copy the value. Both tensors are assumed contiguous NCHW; the
    # launcher asserts this before dispatch.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off the tail block so out-of-range lanes neither load nor store.
    mask = offsets < n_elements

    # Strides for contiguous NCHW
    sN_in = C * H * W
    sC_in = H * W
    sH_in = W
    sW_in = 1

    sN_out = C_out * H_out * W_out
    sC_out = H_out * W_out
    sH_out = W_out
    sW_out = 1  # noqa: F841  -- kept for symmetry with the input strides

    # Decode output linear index into (n, c_out, h_out, w_out).
    # Subtraction after floor-division is used instead of `%` to reuse the
    # quotient already computed.
    n = offsets // sN_out
    rem1 = offsets - n * sN_out
    c_out = rem1 // sC_out
    rem2 = rem1 - c_out * sC_out
    h_out = rem2 // sH_out
    w_out = rem2 - h_out * sH_out

    # Map output channel to input channel and spatial offsets
    r2 = R * R
    c_in = c_out // r2
    remc = c_out - c_in * r2
    dh = remc // R
    dw = remc - dh * R

    # Compute input spatial indices
    h_in = h_out * R + dh
    w_in = w_out * R + dw

    # Compute input linear index
    in_index = n * sN_in + c_in * sC_in + h_in * sH_in + w_in * sW_in

    # Gather from the input and write the contiguous output block.
    x = tl.load(in_ptr + in_index, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

68 

69 

def _launch_pixel_unshuffle_kernel(
    inp: torch.Tensor, downscale_factor: int, out: torch.Tensor
) -> None:
    """Validate arguments and launch the Triton pixel-unshuffle kernel.

    Both ``inp`` (N, C, H, W) and ``out`` (N, C*r*r, H//r, W//r) must be
    contiguous NCHW tensors; the kernel computes linear indices assuming
    contiguous strides. Writes the result into ``out`` in place.

    Args:
        inp: contiguous 4D input tensor.
        downscale_factor: spatial downscale factor r (> 0, divides H and W).
        out: contiguous, pre-allocated output tensor of the expected shape.

    Raises:
        AssertionError: on any contiguity/shape/divisibility violation
            (matches the validation style used elsewhere in this module).
    """
    assert inp.is_contiguous(), "Input must be contiguous (NCHW)"
    assert out.is_contiguous(), "Output must be contiguous (NCHW)"
    assert inp.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    N, C, H, W = inp.shape
    r = int(downscale_factor)
    assert r > 0, "downscale_factor must be > 0"
    assert (H % r == 0) and (
        W % r == 0
    ), "H and W must be divisible by downscale_factor"
    C_out = C * r * r
    H_out = H // r
    W_out = W // r
    assert out.shape == (N, C_out, H_out, W_out), "Output has incorrect shape"

    n_elements = inp.numel()
    if n_elements == 0:
        # Nothing to copy; launching a zero-sized grid would be wasteful.
        return

    BLOCK_SIZE = 1024
    # BLOCK_SIZE is fixed here, so the grid can be a plain precomputed tuple
    # instead of a lambda bound to a name (avoids the E731 lambda-assignment
    # smell; Triton accepts a tuple grid directly).
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Ensure the kernel launches on the device that owns the input tensor.
    with torch_device_fn.device(inp.device):
        pixel_unshuffle_kernel[grid](
            inp,
            out,
            n_elements,
            N,
            C,
            H,
            W,
            r,
            C_out,
            H_out,
            W_out,
            BLOCK_SIZE=BLOCK_SIZE,
        )

108 

109 

def pixel_unshuffle(input, downscale_factor, *, layout=None):
    """Rearrange a (N, C, H, W) tensor into (N, C*r*r, H//r, W//r).

    Inverse of pixel_shuffle: folds each r x r spatial block into the
    channel dimension. The ``layout`` keyword is accepted for signature
    compatibility but is not used by this implementation.

    Args:
        input: 4D tensor (N, C, H, W); made contiguous if it is not.
        downscale_factor: spatial downscale factor r (> 0, divides H and W).

    Returns:
        A newly allocated contiguous tensor of shape (N, C*r*r, H//r, W//r).
    """
    logger.debug("GEMS PIXEL_UNSHUFFLE")
    src = input if input.is_contiguous() else input.contiguous()
    assert src.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = src.shape
    r = int(downscale_factor)
    assert r > 0, "downscale_factor must be > 0"
    assert height % r == 0 and width % r == 0, "H and W must be divisible by downscale_factor"

    result = torch.empty(
        (batch, channels * r * r, height // r, width // r),
        device=src.device,
        dtype=src.dtype,
    )
    _launch_pixel_unshuffle_kernel(src, r, result)
    return result

127 

128 

def pixel_unshuffle_out(input, downscale_factor, out):
    """Out-variant of pixel_unshuffle: write the result into ``out``.

    Validates that ``out`` already has the expected shape
    (N, C*r*r, H//r, W//r), matching dtype and device, and is contiguous,
    then fills it in place via the Triton kernel.

    Args:
        input: 4D tensor (N, C, H, W); made contiguous if it is not.
        downscale_factor: spatial downscale factor r (> 0, divides H and W).
        out: pre-allocated contiguous destination tensor.

    Returns:
        ``out``, for call-chaining convenience.

    Raises:
        ValueError: if ``out`` is not contiguous.
        AssertionError: on any other shape/dtype/device mismatch.
    """
    logger.debug("GEMS PIXEL_UNSHUFFLE_OUT")
    src = input if input.is_contiguous() else input.contiguous()
    assert src.ndim == 4, "Input must be a 4D tensor (N, C, H, W)"
    batch, channels, height, width = src.shape
    r = int(downscale_factor)
    assert r > 0, "downscale_factor must be > 0"
    assert height % r == 0 and width % r == 0, "H and W must be divisible by downscale_factor"

    expected_shape = (batch, channels * r * r, height // r, width // r)
    assert out.shape == expected_shape, f"out must have shape {expected_shape}"
    assert out.dtype == src.dtype, "out dtype must match input dtype"
    assert out.device == src.device, "out device must match input device"
    if not out.is_contiguous():
        raise ValueError("out must be contiguous")

    _launch_pixel_unshuffle_kernel(src, r, out)
    return out