Coverage for src/flag_gems/experimental_ops/replication_pad2d.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def replication_pad2d_kernel(
    in_ptr,  # *Pointer* to input tensor
    out_ptr,  # *Pointer* to output tensor
    N,
    C,
    H,
    W,  # input dimensions
    OH,
    OW,  # output H and W
    PAD_LEFT,
    PAD_TOP,  # padding sizes
    TOTAL_ELEMS,  # total number of output elements
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise replication-pad kernel over a flat (N*C*OH*OW) output.

    Each program instance handles BLOCK_SIZE consecutive output elements.
    For every output coordinate, the corresponding input coordinate is the
    unpadded position clamped into [0, H-1] x [0, W-1], which replicates the
    border values into the padded region. Assumes both tensors are
    contiguous NCHW (the host wrapper enforces contiguity).
    """
    pid = tl.program_id(axis=0)
    # Flat output offsets covered by this program; mask guards the tail block.
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < TOTAL_ELEMS
    # Cast to int64 for safe indexing (N*C*OH*OW can exceed int32 range)
    offs64 = offs.to(tl.int64)
    # Promote every scalar dimension to a 1-element int64 tensor so all of
    # the div/mod arithmetic below happens in int64 (broadcast over offs64).
    OW_i64 = tl.full([1], OW, dtype=tl.int64)
    OH_i64 = tl.full([1], OH, dtype=tl.int64)
    C_i64 = tl.full([1], C, dtype=tl.int64)
    W_i64 = tl.full([1], W, dtype=tl.int64)
    H_i64 = tl.full([1], H, dtype=tl.int64)
    PAD_LEFT_i64 = tl.full([1], PAD_LEFT, dtype=tl.int64)
    PAD_TOP_i64 = tl.full([1], PAD_TOP, dtype=tl.int64)
    # Decompose the flat offset into (n, c, oh, ow) for a contiguous
    # N x C x OH x OW layout (ow varies fastest).
    ow = offs64 % OW_i64
    tmp = offs64 // OW_i64
    oh = tmp % OH_i64
    tmp = tmp // OH_i64
    c = tmp % C_i64
    n = tmp // C_i64
    # Shift back by the padding; coordinates in the pad region go negative
    # (left/top) or past the edge (right/bottom) ...
    ih = oh - PAD_TOP_i64
    iw = ow - PAD_LEFT_i64
    zero = tl.full([1], 0, dtype=tl.int64)
    Hm1 = H_i64 - 1
    Wm1 = W_i64 - 1
    # ... and are clamped to the nearest valid pixel — this IS the
    # "replication" behavior.
    ih = tl.maximum(zero, tl.minimum(Hm1, ih))
    iw = tl.maximum(zero, tl.minimum(Wm1, iw))
    # Contiguous NCHW input index; output index is just the flat offset.
    in_index = ((n * C_i64 + c) * H_i64 + ih) * W_i64 + iw
    out_index = offs64
    x = tl.load(in_ptr + in_index, mask=mask)
    tl.store(out_ptr + out_index, x, mask=mask)
60def _prepare_dims_and_out(input: torch.Tensor, padding, out: torch.Tensor | None):
61 if not isinstance(padding, (tuple, list)) or len(padding) != 4:
62 raise ValueError(
63 "padding must be a sequence of 4 integers: (pad_left, pad_right, pad_top, pad_bottom)"
64 )
65 pad_left, pad_right, pad_top, pad_bottom = map(int, padding)
66 if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0:
67 raise ValueError("replication_pad2d does not support negative padding")
69 if input.dim() == 4:
70 N, C, H, W = input.shape
71 out_shape = (N, C, H + pad_top + pad_bottom, W + pad_left + pad_right)
72 kernel_N, kernel_C = N, C
73 elif input.dim() == 3:
74 C, H, W = input.shape
75 out_shape = (C, H + pad_top + pad_bottom, W + pad_left + pad_right)
76 kernel_N, kernel_C = 1, C
77 else:
78 raise ValueError(
79 "replication_pad2d expects a 3D (C, H, W) or 4D (N, C, H, W) input"
80 )
82 if H <= 0 or W <= 0:
83 raise ValueError(
84 "Input height and width must be greater than 0 for replication padding"
85 )
87 if out is None:
88 out = torch.empty(out_shape, device=input.device, dtype=input.dtype)
89 else:
90 if tuple(out.shape) != tuple(out_shape):
91 raise ValueError(
92 f"Provided out tensor has shape {tuple(out.shape)}, expected {out_shape}"
93 )
94 if out.device != input.device:
95 raise ValueError("Input and out must be on the same device")
96 if out.dtype != input.dtype:
97 raise ValueError("Input and out must have the same dtype")
99 return (
100 kernel_N,
101 kernel_C,
102 H,
103 W,
104 out.shape[-2],
105 out.shape[-1],
106 pad_left,
107 pad_top,
108 ), out
111def _launch_replication_pad2d_kernel(
112 input: torch.Tensor, out: torch.Tensor, kernel_params
113):
114 if not input.is_cuda or not out.is_cuda:
115 raise ValueError("Tensors must be CUDA tensors")
116 if not input.is_contiguous() or not out.is_contiguous():
117 raise ValueError("Only contiguous tensors are supported")
119 N, C, H, W, OH, OW, pad_left, pad_top = kernel_params
120 total_elems = out.numel()
121 if total_elems == 0:
122 return out
124 BLOCK_SIZE = 1024
125 grid = (triton.cdiv(total_elems, BLOCK_SIZE),)
127 replication_pad2d_kernel[grid](
128 input,
129 out,
130 N,
131 C,
132 H,
133 W,
134 OH,
135 OW,
136 pad_left,
137 pad_top,
138 total_elems,
139 BLOCK_SIZE=BLOCK_SIZE,
140 )
141 return out
def replication_pad2d(input: torch.Tensor, padding):
    """Replication-pad a 3D/4D CUDA tensor, allocating a fresh output.

    *padding* is (pad_left, pad_right, pad_top, pad_bottom). Returns the
    newly allocated padded tensor.
    """
    params, result = _prepare_dims_and_out(input, padding, out=None)
    return _launch_replication_pad2d_kernel(input, result, params)
def replication_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad a 3D/4D CUDA tensor into the caller-provided *out*.

    *out* must match the expected padded shape, device and dtype; it is
    filled in place and returned.
    """
    params, result = _prepare_dims_and_out(input, padding, out=out)
    return _launch_replication_pad2d_kernel(input, result, params)