Coverage for src/flag_gems/experimental_ops/im2col.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def im2col_kernel(
    x_ptr,  # pointer to the input, addressed as a contiguous [N, C, H, W] buffer
    out_ptr,  # pointer to the output, addressed as a contiguous [N, C*kH*kW, outH*outW] buffer
    N,  # batch size (not read in the body; the batch index comes from program_id(0))
    C,  # input channels
    H,  # input height
    W,  # input width
    kH,  # kernel height
    kW,  # kernel width
    dH,  # dilation along H
    dW,  # dilation along W
    pH,  # padding along H
    pW,  # padding along W
    sH,  # stride along H
    sW,  # stride along W
    outH,  # number of block positions along H
    outW,  # number of block positions along W
    rows_total,  # C * kH * kW
    L,  # outH * outW
    num_row_tiles,  # ceil_div(rows_total, BLOCK_M)
    BLOCK_M: tl.constexpr,  # output rows handled per program
    BLOCK_N: tl.constexpr,  # output columns handled per program
):
    """Fill one [BLOCK_M, BLOCK_N] tile of the im2col matrix for one batch item.

    Grid axis 0 enumerates (batch item, row tile) pairs, and grid axis 1
    enumerates column tiles. Output element (r, l) of batch item n receives
    the input value at channel c and position (ih, iw), where
    r = (c * kH + kh) * kW + kw, l = oh * outW + ow,
    ih = oh * sH - pH + kh * dH, and iw = ow * sW - pW + kw * dW.
    Positions outside the input (padding) are written as 0.
    """
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)

    # Split axis 0 into the batch item and the row tile within that item.
    n = pid0 // num_row_tiles
    row_tile = pid0 % num_row_tiles

    row_offsets = row_tile * BLOCK_M + tl.arange(0, BLOCK_M)
    col_offsets = pid1 * BLOCK_N + tl.arange(0, BLOCK_N)

    # The last row/column tiles may run past the matrix edge.
    mask_rows = row_offsets < rows_total
    mask_cols = col_offsets < L

    k_area = kH * kW

    # Decompose the flat row index r = (c * kH + kh) * kW + kw.
    c_idx = row_offsets // k_area
    rem = row_offsets % k_area
    kh_idx = rem // kW
    kw_idx = rem % kW

    # Decompose the flat column index l = oh * outW + ow.
    oh_vec = col_offsets // outW
    ow_vec = col_offsets % outW

    # Broadcast to [BLOCK_M, BLOCK_N]
    oh = oh_vec[None, :]
    ow = ow_vec[None, :]
    kh = kh_idx[:, None]
    kw = kw_idx[:, None]
    c = c_idx[:, None]

    # Input coordinate sampled by kernel tap (kh, kw) at output position (oh, ow).
    ih = oh * sH - pH + kh * dH
    iw = ow * sW - pW + kw * dW

    # Taps that fall into the padding region lie outside the input.
    in_h = (ih >= 0) & (ih < H)
    in_w = (iw >= 0) & (iw < W)
    in_bounds = in_h & in_w

    # Base offsets of batch item n, computed in int64 (presumably so large
    # tensors do not overflow 32-bit offsets).
    base_in = (n.to(tl.int64) * C * H * W).to(tl.int64)
    base_out = (n.to(tl.int64) * rows_total * L).to(tl.int64)

    # Compute input pointers
    ptrs_in = (
        x_ptr + base_in + ((c.to(tl.int64) * H + ih.to(tl.int64)) * W + iw.to(tl.int64))
    )

    # Compute output pointers
    ptrs_out = (
        out_ptr
        + base_out
        + (row_offsets[:, None].to(tl.int64) * L + col_offsets[None, :].to(tl.int64))
    )

    mask = mask_rows[:, None] & mask_cols[None, :] & in_bounds

    # Padding taps are masked out of the load and read as `other=0`. The store
    # mask deliberately omits in_bounds, so those zeros are still written.
    vals = tl.load(ptrs_in, mask=mask, other=0)
    tl.store(ptrs_out, vals, mask=(mask_rows[:, None] & mask_cols[None, :]))
88def _parse_2tuple(x, name):
89 if isinstance(x, int):
90 return (x, x)
91 if (
92 isinstance(x, (list, tuple))
93 and len(x) == 2
94 and all(isinstance(v, int) for v in x)
95 ):
96 return (int(x[0]), int(x[1]))
97 raise ValueError(f"{name} must be an int or a tuple/list of two ints")
100def _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW):
101 outH = (H + 2 * pH - (dH * (kH - 1) + 1)) // sH + 1
102 outW = (W + 2 * pW - (dW * (kW - 1) + 1)) // sW + 1
103 return outH, outW
106def _launch_im2col_kernel(x, out, kH, kW, dH, dW, pH, pW, sH, sW):
107 assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
108 x = x.contiguous()
109 out = out.contiguous()
111 N, C, H, W = x.shape
112 outH, outW = _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW)
113 rows_total = C * kH * kW
114 L = outH * outW
116 if rows_total == 0 or L == 0 or N == 0:
117 return # Nothing to do
119 BLOCK_M = 64
120 BLOCK_N = 128
122 num_row_tiles = triton.cdiv(rows_total, BLOCK_M)
123 num_col_tiles = triton.cdiv(L, BLOCK_N)
124 grid = (N * num_row_tiles, num_col_tiles)
126 im2col_kernel[grid](
127 x,
128 out,
129 N,
130 C,
131 H,
132 W,
133 kH,
134 kW,
135 dH,
136 dW,
137 pH,
138 pW,
139 sH,
140 sW,
141 outH,
142 outW,
143 rows_total,
144 L,
145 num_row_tiles,
146 BLOCK_M=BLOCK_M,
147 BLOCK_N=BLOCK_N,
148 num_warps=4,
149 num_stages=2,
150 )
def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
    """Unfold sliding local blocks of ``input`` into columns.

    Args:
        input: Tensor of shape ``(N, C, H, W)`` or ``(C, H, W)``.
        kernel_size, dilation, padding, stride: An int, or a pair ``(h, w)``.

    Returns:
        A tensor of shape ``(N, C*kH*kW, outH*outW)``. For a 3-D input the
        batch dimension is squeezed away, giving ``(C*kH*kW, outH*outW)``.

    Raises:
        ValueError: If ``input`` is not 3-D or 4-D, or an argument is not an
            int or a pair of ints.
    """
    batched = input.ndim == 4
    x = input.unsqueeze(0) if input.ndim == 3 else input
    if x.ndim != 4:
        raise ValueError("im2col expects input of shape (N, C, H, W) or (C, H, W)")

    kH, kW = _parse_2tuple(kernel_size, "kernel_size")
    dH, dW = _parse_2tuple(dilation, "dilation")
    pH, pW = _parse_2tuple(padding, "padding")
    sH, sW = _parse_2tuple(stride, "stride")

    n_batch, channels, height, width = x.shape
    out_h, out_w = _compute_output_dims(
        height, width, kH, kW, dH, dW, pH, pW, sH, sW
    )
    n_rows = channels * kH * kW
    n_cols = out_h * out_w

    result = torch.empty((n_batch, n_rows, n_cols), device=x.device, dtype=x.dtype)
    # Only launch when there is at least one element to produce.
    if n_batch and n_rows and n_cols:
        _launch_im2col_kernel(x, result, kH, kW, dH, dW, pH, pW, sH, sW)

    return result if batched else result.squeeze(0)
def im2col_out(input, kernel_size, dilation=1, padding=0, stride=1, out=None):
    """Unfold sliding local blocks of ``input`` into ``out``.

    Args:
        input: Tensor of shape ``(N, C, H, W)`` or ``(C, H, W)``.
        kernel_size, dilation, padding, stride: An int, or a pair ``(h, w)``.
        out: Optional destination tensor. It must have ``input``'s device and
            dtype and shape ``(N, C*kH*kW, outH*outW)``. When ``N == 1`` the
            2-D shape ``(C*kH*kW, outH*outW)`` is also accepted. It may be
            non-contiguous. If omitted, a new 3-D tensor is allocated, even
            for a 3-D ``input``.

    Returns:
        ``out``, filled in place (or the newly allocated tensor).

    Raises:
        ValueError: If ``input`` is not 3-D or 4-D, an argument is not an int
            or a pair of ints, or ``out`` has the wrong shape, device or dtype.
    """
    x = input
    if x.ndim == 3:
        x = x.unsqueeze(0)
    if x.ndim != 4:
        raise ValueError("im2col_out expects input of shape (N, C, H, W) or (C, H, W)")
    kH, kW = _parse_2tuple(kernel_size, "kernel_size")
    dH, dW = _parse_2tuple(dilation, "dilation")
    pH, pW = _parse_2tuple(padding, "padding")
    sH, sW = _parse_2tuple(stride, "stride")

    N, C, H, W = x.shape
    outH, outW = _compute_output_dims(H, W, kH, kW, dH, dW, pH, pW, sH, sW)
    rows_total = C * kH * kW
    L = outH * outW

    if out is None:
        out = torch.empty((N, rows_total, L), device=x.device, dtype=x.dtype)
    else:
        if out.ndim == 2 and N == 1:
            # Allow (C*kH*kW, L) for single batch for convenience
            expected = (rows_total, L)
        else:
            expected = (N, rows_total, L)
        if tuple(out.shape) != expected:
            raise ValueError(f"out has shape {tuple(out.shape)}, expected {expected}")
        if out.device != x.device or out.dtype != x.dtype:
            raise ValueError("out must have same device and dtype as input")

    if L == 0 or rows_total == 0 or N == 0:
        return out

    # Give a 2-D ``out`` a leading batch dim for the kernel. Unlike ``view``,
    # ``unsqueeze`` succeeds for any strides and still aliases ``out``.
    target = out.unsqueeze(0) if (out.ndim == 2 and N == 1) else out

    # The kernel writes a dense buffer; route strided destinations through a
    # contiguous staging tensor so the results land in ``out``.
    if target.is_contiguous():
        staging = target
    else:
        staging = torch.empty(target.shape, device=target.device, dtype=target.dtype)

    _launch_im2col_kernel(x, staging, kH, kW, dH, dW, pH, pW, sH, sW)

    if staging is not target:
        target.copy_(staging)

    # Return the caller's tensor itself, per the ``out=`` convention.
    return out