Coverage for src/flag_gems/ops/conv1d.py: 89%
27 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
4from flag_gems.ops.conv2d import conv2d
6logger = logging.getLogger(__name__)
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a 1-D convolution by delegating to the 2-D ``conv2d`` kernel.

    A trailing singleton axis is appended to ``input`` and ``weight`` so the
    length axis becomes the height axis of a 2-D convolution. Stride, padding
    and dilation are forwarded as ``(value, 1)``, ``(value, 0)`` and
    ``(value, 1)``, and the trailing axis is squeezed away from the result.

    Args:
        input: Input tensor whose last dimension is the length axis.
            NOTE(review): presumably (N, C_in, L) like
            ``torch.nn.functional.conv1d`` — confirm which ranks ``conv2d``
            accepts.
        weight: Filter tensor whose last dimension is the kernel size.
        bias: Optional bias, forwarded unchanged to ``conv2d``.
        stride: An int, or a list/tuple whose first element is used.
        padding: An int, a list/tuple whose first element is used, or one of
            the strings ``"valid"`` (no padding) or ``"same"`` (output length
            equals input length; only allowed when the stride is 1).
        dilation: An int, or a list/tuple whose first element is used.
        groups: Forwarded unchanged to ``conv2d``.

    Returns:
        The ``conv2d`` result with its trailing singleton axis removed.

    Raises:
        ValueError: If ``padding`` is an unsupported string, or if
            ``padding="same"`` is combined with a stride other than 1.
    """
    logger.debug("GEMS CONV1D")
    stride_width = stride[0] if isinstance(stride, (list, tuple)) else stride
    dilation_width = dilation[0] if isinstance(dilation, (list, tuple)) else dilation

    # Leading output positions to discard. This is only non-zero for "same"
    # padding when the total padding amount is odd.
    trim = 0
    if isinstance(padding, str):
        if padding == "same":
            # Check the normalised stride so that stride=(1,) / [1] is accepted.
            # Raise instead of assert: asserts vanish under `python -O`.
            if stride_width != 1:
                raise ValueError(
                    "Doesn't support any stride values other than 1 in "
                    f"padding = 'same' mode, received stride value {stride}"
                )
            kernel_size = weight.shape[-1]
            total_padding = dilation_width * (kernel_size - 1)
            # conv2d only pads symmetrically, so round the per-side padding up.
            # If total_padding is odd, this yields one extra output position on
            # the left. Dropping it leaves floor(total/2) effective padding on
            # the left and ceil(total/2) on the right.
            padding_width = math.ceil(total_padding / 2)
            trim = 2 * padding_width - total_padding
        elif padding == "valid":
            padding_width = 0
        else:
            raise ValueError(
                f"Unsupported padding mode: {padding}, only 'valid' or 'same' are allowed."
            )
    elif isinstance(padding, (list, tuple)):
        padding_width = padding[0]
    else:
        padding_width = padding

    output = conv2d(
        input.unsqueeze(-1),
        weight.unsqueeze(-1),
        bias,
        (stride_width, 1),
        (padding_width, 0),
        (dilation_width, 1),
        groups,
    ).squeeze(-1)
    return output[..., trim:] if trim else output