Coverage for src/flag_gems/ops/reflection_pad2d.py: 46%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
# One program fills one BLOCK_HW-sized tile of the padded output for one
# (flattened) batch slice.  Grid: axis 0 = batch index, axis 1 = tile index
# over the flattened H_out*W_out output plane.
@triton.jit
def reflection_pad2d_kernel(
    in_ptr,  # pointer to contiguous input, logical shape (B, H_in, W_in)
    out_ptr,  # pointer to contiguous output, logical shape (B, H_out, W_out)
    B,  # number of flattened batch slices
    H_in,
    W_in,
    pad_left,
    pad_top,
    H_out,  # H_in + pad_top + pad_bottom
    W_out,  # W_in + pad_left + pad_right
    BLOCK_HW: tl.constexpr,  # output elements handled per program
):
    pid_b = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Flatten 2D index to 1D for block processing
    offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW)
    # Decode to (h, w) coordinates
    h_idx = offs_n // W_out
    w_idx = offs_n % W_out

    # Mask out the ragged tail of the last tile (and out-of-range batches).
    mask = (offs_n < H_out * W_out) & (pid_b < B)

    base_in = pid_b * (H_in * W_in)
    base_out = pid_b * (H_out * W_out)

    # Compute reflected indices for height: map y = h - pad_top into [0, H_in).
    # |y| mod 2*(H_in-1) folds y into one reflection period; folded values
    # past H_in-1 are mirrored back with pH - m_h.
    # NOTE(review): assumes H_in >= 2, otherwise the modulus pH is 0 — the
    # launcher validates this before dispatching the kernel.
    y = h_idx.to(tl.int32) - pad_top
    Hm1 = H_in - 1
    pH = 2 * Hm1
    t_h = tl.abs(y)
    m_h = t_h % pH
    ih = tl.where(m_h < H_in, m_h, pH - m_h)

    # Compute reflected indices for width (same scheme; assumes W_in >= 2).
    x = w_idx.to(tl.int32) - pad_left
    Wm1 = W_in - 1
    pW = 2 * Wm1
    t_w = tl.abs(x)
    m_w = t_w % pW
    iw = tl.where(m_w < W_in, m_w, pW - m_w)

    # Gather from the reflected input location and store to the output.
    in_offs = ih * W_in + iw
    vals = tl.load(in_ptr + base_in + in_offs, mask=mask, other=0)
    tl.store(out_ptr + base_out + offs_n, vals, mask=mask)
# Plain element-wise copy of a contiguous (B, H, W) tensor; used for the
# zero-padding fast path.  Grid mirrors reflection_pad2d_kernel:
# axis 0 = batch slice, axis 1 = tile over the flattened H*W plane.
@triton.jit
def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr):
    batch = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Flat offsets of this tile within one H*W plane.
    offsets = tile * BLOCK_HW + tl.arange(0, BLOCK_HW)
    valid = (offsets < H * W) & (batch < B)

    plane_start = batch * (H * W)
    data = tl.load(in_ptr + plane_start + offsets, mask=valid, other=0)
    tl.store(out_ptr + plane_start + offsets, data, mask=valid)
def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Reflection-pad the last two dimensions of ``input``.

    Args:
        input: CUDA tensor with at least 3 dimensions; the last two are
            interpreted as (H, W) and all leading dims are flattened into
            a batch axis.
        padding: sequence of 4 ints ``(pad_left, pad_right, pad_top,
            pad_bottom)``, each >= 0 and strictly less than the size of the
            corresponding spatial dimension.
        out: optional pre-allocated output tensor of shape
            ``(*input.shape[:-2], H + pad_top + pad_bottom,
            W + pad_left + pad_right)`` with the same dtype/device as
            ``input``.  May be non-contiguous; results are copied back.

    Returns:
        The padded tensor — ``out`` if provided, otherwise a new tensor.

    Raises:
        ValueError: on malformed padding, non-CUDA input, too-few input
            dimensions, or an ``out`` tensor with the wrong shape, dtype,
            or device.
    """
    # Validate padding format.
    if not isinstance(padding, (list, tuple)):
        raise ValueError("padding must be a sequence")
    if len(padding) != 4:
        raise ValueError(
            "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)"
        )
    pad_left, pad_right, pad_top, pad_bottom = [int(p) for p in padding]

    # Validate padding values.
    if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0:
        raise ValueError("padding values must be >= 0")

    # Validate input.
    if input.dim() < 3:
        raise ValueError("input must have at least 3 dimensions")
    if not input.is_cuda:
        raise ValueError("input must be a CUDA tensor")

    x = input.contiguous()
    H_in = int(x.shape[-2])
    W_in = int(x.shape[-1])
    if H_in <= 0 or W_in <= 0:
        raise ValueError("spatial dimensions must be > 0")

    has_padding = pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0
    # The reflection kernel folds indices modulo 2*(dim - 1), so both spatial
    # dims must be >= 2 whenever any padding is applied.  Size-1 dims are
    # fine with all-zero padding (pure copy path).  Bug fix: the old check
    # rejected size-1 inputs even when all padding was zero, contradicting
    # its own error message.
    if has_padding and (H_in < 2 or W_in < 2):
        raise ValueError(
            "input spatial dimensions must be at least 2 for reflection padding when padding > 0"
        )
    if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in:
        raise ValueError(
            "padding values must be less than the input spatial dimensions for reflection padding"
        )

    H_out = H_in + pad_top + pad_bottom
    W_out = W_in + pad_left + pad_right

    leading_shape = x.shape[:-2]
    B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1

    # Handle the output tensor.  `work` is the contiguous buffer the kernels
    # write into; it is `out` itself whenever possible.
    if out is None:
        out = torch.empty(
            (*leading_shape, H_out, W_out), device=x.device, dtype=x.dtype
        )
        work = out
    else:
        if not out.is_cuda:
            raise ValueError("out must be a CUDA tensor")
        expected_shape = (*leading_shape, H_out, W_out)
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")
        # Bug fix: the old code did `out = out.contiguous()`, which returns a
        # *copy* for non-contiguous tensors, so the kernel's writes never
        # reached the caller's tensor.  Compute into a contiguous scratch
        # buffer instead and copy back afterwards.
        if out.is_contiguous():
            work = out
        else:
            work = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    BLOCK_HW = 256
    if not has_padding:
        # No padding: just copy.
        grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW))
        copy_tensor_kernel[grid](x, work, B, H_in, W_in, BLOCK_HW=BLOCK_HW)
    else:
        grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW))
        reflection_pad2d_kernel[grid](
            x, work, B, H_in, W_in, pad_left, pad_top, H_out, W_out, BLOCK_HW=BLOCK_HW
        )

    if work is not out:
        out.copy_(work)
    return out
def reflection_pad2d(input: torch.Tensor, padding):
    """Allocating entry point: reflection-pad ``input`` into a new tensor."""
    logger.debug("GEMS REFLECTION_PAD2D")
    # Delegate to the shared launcher; no `out` means a fresh allocation.
    return launch_reflection_pad2d(input, padding)
def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Out-variant entry point: reflection-pad ``input`` into ``out``."""
    logger.debug("GEMS REFLECTION_PAD2D_OUT")
    # Delegate to the shared launcher, which validates and fills `out`.
    result = launch_reflection_pad2d(input, padding, out=out)
    return result