# Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv1d.py: 0%
# 27 statements
# coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
4from .conv2d import conv2d
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module with any leading dots stripped. Unless propagation is disabled
# elsewhere, its records also reach handlers attached to "flag_gems".
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
def _leading_value(value):
    """Return ``value[0]`` for a list/tuple argument, otherwise ``value`` itself."""
    if isinstance(value, (list, tuple)):
        return value[0]
    return value


def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a 1D convolution by delegating to the 2D ``conv2d`` kernel.

    A trailing size-1 dimension is appended to ``input`` and ``weight``. The
    2D convolution runs with unit stride, zero padding and unit dilation on
    that extra axis. The trailing dimension is then squeezed away again.

    Args:
        input: Input tensor. Its last dimension is treated as the sequence
            length (presumably ``(N, C_in, L)`` as in
            ``torch.nn.functional.conv1d``; confirm against callers).
        weight: Filter tensor. Its last dimension is the kernel size.
        bias: Optional bias, forwarded unchanged to ``conv2d``.
        stride: Int, or a list/tuple whose first element is used.
        padding: One of the following:
            - an int;
            - a list/tuple whose first element is used;
            - ``"valid"`` (no padding);
            - ``"same"`` (output length equals input length; stride must be 1).
        dilation: Int, or a list/tuple whose first element is used.
        groups: Number of groups, forwarded unchanged to ``conv2d``.

    Returns:
        The ``conv2d`` result with its trailing size-1 dimension removed.

    Raises:
        AssertionError: If ``padding == "same"`` and the stride is not 1.
        ValueError: If ``padding`` is a string other than "valid"/"same".
    """
    logger.debug("GEMS CONV1D")
    stride_width = _leading_value(stride)
    dilation_width = _leading_value(dilation)

    # Number of leading output positions to drop; only non-zero for "same".
    trim = 0
    if isinstance(padding, str):
        if padding == "same":
            # Check the extracted scalar so that stride=(1,) / [1] are
            # accepted just like stride=1. Use an explicit raise instead of
            # `assert` so the check survives `python -O`. AssertionError is
            # kept because it is what callers have always received.
            if stride_width != 1:
                raise AssertionError(
                    "Doesn't support any stride values other than 1 "
                    "in padding = 'same' mode, "
                    f"received stride value {stride}"
                )
            kernel_size = weight.shape[-1]
            # With stride 1, keeping the length unchanged needs
            # dilation * (kernel_size - 1) padded elements in total.
            total_padding = dilation_width * (kernel_size - 1)
            # conv2d pads symmetrically, so round up to cover an odd total.
            # That adds one extra output at the start. Trimming it is
            # equivalent to padding total // 2 on the left and the rest on
            # the right.
            padding_width = math.ceil(total_padding / 2)
            trim = 2 * padding_width - total_padding
        elif padding == "valid":
            # "valid" means no implicit padding at all.
            padding_width = 0
        else:
            raise ValueError(
                f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
            )
    else:
        padding_width = _leading_value(padding)

    output = conv2d(
        input.unsqueeze(-1),
        weight.unsqueeze(-1),
        bias,
        (stride_width, 1),
        (padding_width, 0),
        (dilation_width, 1),
        groups,
    ).squeeze(-1)
    if trim:
        output = output[..., trim:]
    return output