Coverage for src/flag_gems/experimental_ops/replication_pad1d.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def replication_pad1d_kernel(
    in_ptr,
    out_ptr,
    B: tl.constexpr,  # batch count; unused in the body, presumably kept for launch-signature symmetry — TODO confirm
    C: tl.constexpr,  # channel count; used to split the flattened (n, c) program id
    W_in,
    W_out,
    pad_left,
    in_stride_n,
    in_stride_c,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_w,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one BLOCK_SIZE-wide slice of the padded output for one (n, c) pair.

    Grid layout: axis 0 tiles the output width in BLOCK_SIZE chunks,
    axis 1 enumerates flattened (batch, channel) pairs. Replication
    padding maps each output position to the clamped input position
    ``clamp(w_out - pad_left, 0, W_in - 1)``, so the left/right edges
    repeat the first/last input element.
    """
    pid_w = tl.program_id(axis=0)
    pid_nc = tl.program_id(axis=1)

    # Recover (batch, channel) from the flattened second grid axis.
    n = pid_nc // C
    c = pid_nc % C

    off_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = off_w < W_out  # last tile may extend past W_out

    # Compute clamped source indices for replication pad
    w_in = off_w - pad_left
    w_in = tl.maximum(w_in, 0)
    w_in = tl.minimum(w_in, W_in - 1)

    # Base offsets (int64 to avoid overflow on large tensors)
    base_in = n.to(tl.int64) * in_stride_n + c.to(tl.int64) * in_stride_c
    base_out = n.to(tl.int64) * out_stride_n + c.to(tl.int64) * out_stride_c

    ptrs_in = in_ptr + base_in + w_in.to(tl.int64) * in_stride_w
    ptrs_out = out_ptr + base_out + off_w.to(tl.int64) * out_stride_w

    x = tl.load(ptrs_in, mask=mask, other=0)
    tl.store(ptrs_out, x, mask=mask)
48def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
49 if not input.is_cuda or not out.is_cuda:
50 raise RuntimeError("Triton kernels require CUDA tensors")
52 if isinstance(padding, torch.Tensor):
53 padding = tuple(padding.tolist())
54 left, right = int(padding[0]), int(padding[1])
55 if left < 0 or right < 0:
56 raise ValueError("Padding values must be non-negative for replication_pad1d")
58 dim = input.dim()
59 if dim not in (2, 3):
60 raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")
62 if dim == 3:
63 N, C, W_in = input.shape
64 B = N
65 in_s_n, in_s_c, in_s_w = input.stride()
66 out_s_n, out_s_c, out_s_w = out.stride()
67 expected_out_shape = (N, C, W_in + left + right)
68 else:
69 C, W_in = input.shape
70 B = 1
71 N = 1 # dummy
72 in_s_c, in_s_w = input.stride()
73 in_s_n = 0
74 if out.dim() == 2:
75 out_s_c, out_s_w = out.stride()
76 out_s_n = 0
77 elif out.dim() == 3:
78 out_s_n, out_s_c, out_s_w = out.stride()
79 else:
80 raise ValueError("Output tensor has invalid dimensions")
81 expected_out_shape = (C, W_in + left + right)
83 W_out = W_in + left + right
85 # Validate output shape
86 if tuple(out.shape) != expected_out_shape:
87 raise ValueError(
88 f"Output tensor has incorrect shape. Expected {expected_out_shape}, got {tuple(out.shape)}"
89 )
91 grid = (triton.cdiv(W_out, 256), B * C)
92 replication_pad1d_kernel[grid](
93 input,
94 out,
95 B,
96 C,
97 W_in,
98 W_out,
99 left,
100 in_s_n if dim == 3 else in_s_n,
101 in_s_c,
102 in_s_w,
103 out_s_n if (dim == 3 or out.dim() == 3) else 0,
104 out_s_c,
105 out_s_w,
106 BLOCK_SIZE=256,
107 )
108 return out
def replication_pad1d(input: torch.Tensor, padding):
    """Pad the last dimension of ``input`` by replicating its edge values.

    Args:
        input: tensor of shape (C, W) or (N, C, W).
        padding: pair ``(left, right)`` of pad widths; a tensor is
            converted via ``tolist()``.

    Returns:
        A freshly allocated tensor of shape ``(..., W + left + right)``.

    Raises:
        ValueError: if ``input`` is not 2D or 3D.
    """
    if isinstance(padding, torch.Tensor):
        padding = tuple(padding.tolist())
    left, right = int(padding[0]), int(padding[1])
    if input.dim() not in (2, 3):
        raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")
    # Output keeps every leading dimension; only the width grows.
    # (Collapses the previously duplicated 2D/3D allocation branches.)
    out_shape = (*input.shape[:-1], input.shape[-1] + left + right)
    out = torch.empty(
        out_shape,
        device=input.device,
        dtype=input.dtype,
        layout=input.layout,
    )
    return _launch_replication_pad1d_kernel(input, (left, right), out)
def replication_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad the last dimension of ``input`` into ``out``.

    Validates that ``out`` matches ``input`` in dtype and device, then
    delegates to the shared launcher, which performs the remaining
    shape/padding checks and runs the kernel.

    Returns:
        ``out``.

    Raises:
        ValueError: on dtype or device mismatch between ``out`` and ``input``.
    """
    pad = padding
    if isinstance(pad, torch.Tensor):
        pad = tuple(pad.tolist())
    left = int(pad[0])
    right = int(pad[1])

    # Fail fast on incompatible destination tensors before launching.
    if out.dtype != input.dtype:
        raise ValueError("Output dtype must match input dtype")
    if out.device != input.device:
        raise ValueError("Output device must match input device")

    return _launch_replication_pad1d_kernel(input, (left, right), out)