Coverage for src/flag_gems/ops/replication_pad1d.py: 66%
79 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def replication_pad1d_kernel(
    in_ptr,
    out_ptr,
    B: tl.constexpr,
    C: tl.constexpr,
    W_in,
    W_out,
    pad_left,
    in_stride_n,
    in_stride_c,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_w,
    BLOCK_SIZE: tl.constexpr,
):
    # Replication padding along the last (width) axis.
    #
    # Grid axis 0 tiles the output width in chunks of BLOCK_SIZE; grid axis 1
    # enumerates the flattened (n, c) rows.  Each program copies one tile of
    # one row, reading every output position from the nearest valid input
    # position.  `B` is not referenced in the body; it is kept so the
    # positional launch signature stays unchanged.
    tile_id = tl.program_id(axis=0)
    row_id = tl.program_id(axis=1)

    # Split the flattened row index back into (batch, channel) and widen to
    # 64-bit before multiplying by strides.
    batch_idx = (row_id // C).to(tl.int64)
    chan_idx = (row_id % C).to(tl.int64)
    row_base_in = batch_idx * in_stride_n + chan_idx * in_stride_c
    row_base_out = batch_idx * out_stride_n + chan_idx * out_stride_c

    # Output positions handled by this program; the tail tile is masked.
    dst_w = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = dst_w < W_out

    # Shift by the left pad, then clamp into [0, W_in - 1] so that padded
    # positions replicate the first / last input element.
    src_w = tl.minimum(tl.maximum(dst_w - pad_left, 0), W_in - 1)

    src_ptrs = in_ptr + row_base_in + src_w.to(tl.int64) * in_stride_w
    dst_ptrs = out_ptr + row_base_out + dst_w.to(tl.int64) * out_stride_w

    values = tl.load(src_ptrs, mask=in_range, other=0)
    tl.store(dst_ptrs, values, mask=in_range)
55def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
56 if isinstance(padding, torch.Tensor):
57 padding = tuple(padding.tolist())
58 left, right = int(padding[0]), int(padding[1])
59 if left < 0 or right < 0:
60 raise ValueError("Padding values must be non-negative for replication_pad1d")
62 dim = input.dim()
63 if dim not in (2, 3):
64 raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")
66 if dim == 3:
67 N, C, W_in = input.shape
68 B = N
69 in_s_n, in_s_c, in_s_w = input.stride()
70 out_s_n, out_s_c, out_s_w = out.stride()
71 expected_out_shape = (N, C, W_in + left + right)
72 else:
73 C, W_in = input.shape
74 B = 1
75 in_s_c, in_s_w = input.stride()
76 in_s_n = 0
77 if out.dim() == 2:
78 out_s_c, out_s_w = out.stride()
79 out_s_n = 0
80 elif out.dim() == 3:
81 out_s_n, out_s_c, out_s_w = out.stride()
82 else:
83 raise ValueError("Output tensor has invalid dimensions")
84 expected_out_shape = (C, W_in + left + right)
86 W_out = W_in + left + right
88 # Validate output shape
89 if tuple(out.shape) != expected_out_shape:
90 raise ValueError(
91 f"Output tensor has incorrect shape. Expected {expected_out_shape}, got {tuple(out.shape)}"
92 )
94 grid = (triton.cdiv(W_out, 256), B * C)
95 with torch_device_fn.device(input.device):
96 replication_pad1d_kernel[grid](
97 input,
98 out,
99 B,
100 C,
101 W_in,
102 W_out,
103 left,
104 in_s_n if dim == 3 else in_s_n,
105 in_s_c,
106 in_s_w,
107 out_s_n if (dim == 3 or out.dim() == 3) else 0,
108 out_s_c,
109 out_s_w,
110 BLOCK_SIZE=256,
111 )
112 return out
def replication_pad1d(input: torch.Tensor, padding):
    """Return a new tensor equal to ``input`` replication-padded along its last axis.

    Args:
        input: 2D ``(C, W)`` or 3D ``(N, C, W)`` tensor.
        padding: ``(left, right)`` pair, given as a sequence or a tensor.

    Returns:
        A freshly allocated tensor with the same device, dtype and layout as
        ``input`` and a last dimension of ``W + left + right``.

    Raises:
        ValueError: If ``input`` is neither 2D nor 3D.
    """
    logger.debug("GEMS REPLICATION_PAD1D")
    pads = tuple(padding.tolist()) if isinstance(padding, torch.Tensor) else padding
    pad_l, pad_r = int(pads[0]), int(pads[1])

    if input.dim() not in (2, 3):
        raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")

    # Only the trailing (width) dimension grows; leading dims are copied as-is.
    *lead_dims, width = input.shape
    result = torch.empty(
        (*lead_dims, width + pad_l + pad_r),
        device=input.device,
        dtype=input.dtype,
        layout=input.layout,
    )
    return _launch_replication_pad1d_kernel(input, (pad_l, pad_r), result)
def replication_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad ``input`` along its last axis, writing into ``out``.

    Args:
        input: Tensor to pad.
        padding: ``(left, right)`` pair, given as a sequence or a tensor.
        out: Pre-allocated destination; must share dtype and device with
            ``input``.

    Returns:
        Whatever ``_launch_replication_pad1d_kernel`` returns for ``out``.

    Raises:
        ValueError: If ``out`` differs from ``input`` in dtype or device.
    """
    logger.debug("GEMS REPLICATION_PAD1D_OUT")
    pads = tuple(padding.tolist()) if isinstance(padding, torch.Tensor) else padding
    pad_pair = (int(pads[0]), int(pads[1]))

    # The kernel copies raw elements, so the destination must match the
    # source's dtype and live on the same device.
    if input.dtype != out.dtype:
        raise ValueError("Output dtype must match input dtype")
    if input.device != out.device:
        raise ValueError("Output device must match input device")

    return _launch_replication_pad1d_kernel(input, pad_pair, out)