Coverage for src/flag_gems/experimental_ops/im2col.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def im2col_kernel(
    x_ptr,  # *Pointer* to input tensor [N, C, H, W]
    out_ptr,  # *Pointer* to output tensor [N, C*kH*kW, outH*outW]
    N,
    C,
    H,
    W,
    kH,
    kW,
    dH,
    dW,
    pH,
    pW,
    sH,
    sW,
    outH,
    outW,
    rows_total,  # C * kH * kW
    L,  # outH * outW
    num_row_tiles,  # ceil_div(rows_total, BLOCK_M)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program fills one [BLOCK_M, BLOCK_N] tile of one batch element's
    # output matrix. Grid axis 0 packs (batch, row tile); axis 1 is the
    # column tile. Both tensors are addressed as dense row-major buffers.
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    batch = pid_row // num_row_tiles
    tile = pid_row % num_row_tiles

    # Output coordinates of this tile, already shaped for broadcasting:
    # rows_2d is [BLOCK_M, 1], cols_2d is [1, BLOCK_N].
    rows_2d = (tile * BLOCK_M + tl.arange(0, BLOCK_M))[:, None]
    cols_2d = (pid_col * BLOCK_N + tl.arange(0, BLOCK_N))[None, :]
    tile_mask = (rows_2d < rows_total) & (cols_2d < L)

    # An output row encodes (channel, kernel row, kernel column).
    window = kH * kW
    chan = rows_2d // window
    tap = rows_2d % window
    tap_h = tap // kW
    tap_w = tap % kW

    # An output column encodes (output row, output column) of the sliding grid.
    out_h = cols_2d // outW
    out_w = cols_2d % outW

    # Input coordinates sampled by this (row, column) pair; the sum
    # broadcasts to [BLOCK_M, BLOCK_N].
    src_h = out_h * sH - pH + tap_h * dH
    src_w = out_w * sW - pW + tap_w * dW
    inside = (src_h >= 0) & (src_h < H) & (src_w >= 0) & (src_w < W)

    # All address arithmetic is carried out in 64-bit integers.
    batch64 = batch.to(tl.int64)
    src_index = (chan.to(tl.int64) * H + src_h.to(tl.int64)) * W + src_w.to(tl.int64)
    dst_index = rows_2d.to(tl.int64) * L + cols_2d.to(tl.int64)
    src_ptrs = x_ptr + batch64 * C * H * W + src_index
    dst_ptrs = out_ptr + batch64 * rows_total * L + dst_index

    # Positions that fall into the padding region load 0 instead of memory.
    patch = tl.load(src_ptrs, mask=tile_mask & inside, other=0)
    tl.store(dst_ptrs, patch, mask=tile_mask)

86 

87 

88def _parse_2tuple(x, name): 

89 if isinstance(x, int): 

90 return (x, x) 

91 if ( 

92 isinstance(x, (list, tuple)) 

93 and len(x) == 2 

94 and all(isinstance(v, int) for v in x) 

95 ): 

96 return (int(x[0]), int(x[1])) 

97 raise ValueError(f"{name} must be an int or a tuple/list of two ints") 

98 

99 

100def _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW): 

101 outH = (H + 2 * pH - (dH * (kH - 1) + 1)) // sH + 1 

102 outW = (W + 2 * pW - (dW * (kW - 1) + 1)) // sW + 1 

103 return outH, outW 

104 

105 

106def _launch_im2col_kernel(x, out, kH, kW, dH, dW, pH, pW, sH, sW): 

107 assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors" 

108 x = x.contiguous() 

109 out = out.contiguous() 

110 

111 N, C, H, W = x.shape 

112 outH, outW = _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW) 

113 rows_total = C * kH * kW 

114 L = outH * outW 

115 

116 if rows_total == 0 or L == 0 or N == 0: 

117 return # Nothing to do 

118 

119 BLOCK_M = 64 

120 BLOCK_N = 128 

121 

122 num_row_tiles = triton.cdiv(rows_total, BLOCK_M) 

123 num_col_tiles = triton.cdiv(L, BLOCK_N) 

124 grid = (N * num_row_tiles, num_col_tiles) 

125 

126 im2col_kernel[grid]( 

127 x, 

128 out, 

129 N, 

130 C, 

131 H, 

132 W, 

133 kH, 

134 kW, 

135 dH, 

136 dW, 

137 pH, 

138 pW, 

139 sH, 

140 sW, 

141 outH, 

142 outW, 

143 rows_total, 

144 L, 

145 num_row_tiles, 

146 BLOCK_M=BLOCK_M, 

147 BLOCK_N=BLOCK_N, 

148 num_warps=4, 

149 num_stages=2, 

150 ) 

151 

152 

def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
    """Extract sliding local blocks of ``input`` into columns.

    Args:
        input: Tensor of shape (N, C, H, W) or (C, H, W).
        kernel_size: Int or pair of ints ``(kH, kW)``.
        dilation: Int or pair of ints ``(dH, dW)``.
        padding: Int or pair of ints ``(pH, pW)``.
        stride: Int or pair of ints ``(sH, sW)``.

    Returns:
        A new tensor of shape (N, C*kH*kW, outH*outW) for 4-D input, or
        (C*kH*kW, outH*outW) for 3-D input, with the device and dtype of
        ``input``. If no sliding block fits, the last dimension is 0 and
        nothing is launched.

    Raises:
        ValueError: If ``input`` is not 3-D or 4-D, or if any of the
            size arguments is not an int or a pair of ints.
    """
    batched = input.ndim == 4
    x = input.unsqueeze(0) if input.ndim == 3 else input
    if x.ndim != 4:
        raise ValueError("im2col expects input of shape (N, C, H, W) or (C, H, W)")

    kH, kW = _parse_2tuple(kernel_size, "kernel_size")
    dH, dW = _parse_2tuple(dilation, "dilation")
    pH, pW = _parse_2tuple(padding, "padding")
    sH, sW = _parse_2tuple(stride, "stride")

    N, C, H, W = x.shape
    outH, outW = _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW)
    # A kernel larger than the padded input gives negative sizes; two of
    # them would multiply to a positive block count, so clamp each at zero.
    outH, outW = max(outH, 0), max(outW, 0)
    rows_total = C * kH * kW
    L = outH * outW

    out = torch.empty((N, rows_total, L), device=x.device, dtype=x.dtype)
    if L != 0 and rows_total != 0 and N != 0:
        _launch_im2col_kernel(x, out, kH, kW, dH, dW, pH, pW, sH, sW)

    # Drop the batch dimension that was added for unbatched input.
    return out if batched else out.squeeze(0)

175 

176 

def im2col_out(input, kernel_size, dilation=1, padding=0, stride=1, out=None):
    """Extract sliding local blocks of ``input`` into columns, writing to ``out``.

    Args:
        input: Tensor of shape (N, C, H, W) or (C, H, W).
        kernel_size: Int or pair of ints ``(kH, kW)``.
        dilation: Int or pair of ints ``(dH, dW)``.
        padding: Int or pair of ints ``(pH, pW)``.
        stride: Int or pair of ints ``(sH, sW)``.
        out: Optional destination tensor of shape (N, C*kH*kW, outH*outW).
            For a single batch element, (C*kH*kW, outH*outW) is also
            accepted. It may be non-contiguous. When omitted, a new
            (N, C*kH*kW, outH*outW) tensor is allocated.

    Returns:
        ``out`` itself when it was provided, otherwise the newly allocated
        tensor. The allocated tensor keeps its leading batch dimension even
        for 3-D input.

    Raises:
        ValueError: If ``input`` is not 3-D or 4-D, if a size argument is
            not an int or a pair of ints, or if ``out`` has the wrong
            shape, device or dtype.
    """
    x = input.unsqueeze(0) if input.ndim == 3 else input
    if x.ndim != 4:
        raise ValueError("im2col_out expects input of shape (N, C, H, W) or (C, H, W)")

    kH, kW = _parse_2tuple(kernel_size, "kernel_size")
    dH, dW = _parse_2tuple(dilation, "dilation")
    pH, pW = _parse_2tuple(padding, "padding")
    sH, sW = _parse_2tuple(stride, "stride")

    N, C, H, W = x.shape
    outH, outW = _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW)
    # A kernel larger than the padded input gives negative sizes; two of
    # them would multiply to a positive block count, so clamp each at zero.
    outH, outW = max(outH, 0), max(outW, 0)
    rows_total = C * kH * kW
    L = outH * outW
    batched_shape = (N, rows_total, L)

    if out is None:
        out = torch.empty(batched_shape, device=x.device, dtype=x.dtype)
    else:
        # Allow (C*kH*kW, L) for single batch for convenience
        if out.ndim == 2 and N == 1:
            expected = (rows_total, L)
        else:
            expected = batched_shape
        if tuple(out.shape) != expected:
            raise ValueError(f"out has shape {tuple(out.shape)}, expected {expected}")
        if out.device != x.device or out.dtype != x.dtype:
            raise ValueError("out must have same device and dtype as input")

    if L == 0 or rows_total == 0 or N == 0:
        return out

    # The kernel writes a dense (N, rows_total, L) buffer. A contiguous
    # ``out`` (2-D or 3-D) can be viewed that way directly; a strided one
    # cannot, so the result is staged in a dense buffer and copied back.
    staged = not out.is_contiguous()
    if staged:
        target = torch.empty(batched_shape, device=x.device, dtype=x.dtype)
    else:
        target = out.view(batched_shape)

    _launch_im2col_kernel(x, target, kH, kW, dH, dW, pH, pW, sH, sW)

    if staged:
        out.copy_(target.view(out.shape))

    return out