Coverage for src/flag_gems/ops/conv1d.py: 89%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import math 

3 

4from flag_gems.ops.conv2d import conv2d 

5 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

7 

8 

def _to_width(value):
    """Return ``value[0]`` for a list/tuple argument, otherwise ``value`` itself.

    ``stride``, ``padding`` and ``dilation`` may be given either as a scalar or
    as a sequence; only the first element of a sequence is used.
    """
    if isinstance(value, (list, tuple)):
        return value[0]
    return value


def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """1-D convolution implemented by delegating to the 2-D ``conv2d`` kernel.

    ``input`` and ``weight`` each get a trailing size-1 dimension
    (``unsqueeze(-1)``), the convolution is run by ``conv2d`` with
    stride ``(s, 1)``, padding ``(p, 0)`` and dilation ``(d, 1)``, and the
    helper dimension is squeezed away afterwards.

    Args:
        input: Input tensor; its last dimension is the length convolved over
            (presumably ``(N, C_in, L)`` -- TODO confirm against ``conv2d``).
        weight: Filter tensor; its last dimension is the kernel size.
        bias: Optional bias, forwarded unchanged to ``conv2d``.
        stride: int or list/tuple (first element used).
        padding: int, list/tuple (first element used), or the string
            ``"same"`` / ``"valid"``.
        dilation: int or list/tuple (first element used).
        groups: Forwarded unchanged to ``conv2d``.

    Returns:
        The ``conv2d`` result with the trailing helper dimension removed
        (and, for ``padding="same"``, trimmed to the input length).

    Raises:
        AssertionError: ``padding="same"`` with a stride other than 1.
        ValueError: ``padding`` is a string other than ``"same"``/``"valid"``.
    """
    logger.debug("GEMS CONV1D")
    stride_width = _to_width(stride)
    dilation_width = _to_width(dilation)

    # Leading output positions to drop afterwards; only non-zero for "same".
    trim = 0
    if isinstance(padding, str):
        if padding == "same":
            # Check the normalized width so stride=(1,) / [1] is accepted just
            # like stride=1. NOTE: kept as `assert` so the raised exception type
            # is unchanged for existing callers.
            assert stride_width == 1, (
                "Doesn't support any stride values other than 1 "
                f"in padding = 'same' mode, received stride value {stride}"
            )
            # With stride 1, keeping the output length equal to the input
            # length needs dilation * (kernel_size - 1) padding in total.
            total_padding = dilation_width * (weight.shape[-1] - 1)
            # conv2d only takes symmetric padding: pad ceil(total / 2) per side
            # using exact integer arithmetic.
            padding_width = (total_padding + 1) // 2
            # When total_padding is odd this produces one extra output element;
            # dropping it from the front leaves floor(total/2) effective padding
            # on the left and ceil(total/2) on the right.
            trim = 2 * padding_width - total_padding
        elif padding == "valid":
            padding_width = 0
        else:
            raise ValueError(
                f"Unsupported padding mode: {padding}, only 'valid' or 'same' are allowed."
            )
    else:
        padding_width = _to_width(padding)

    output = conv2d(
        input.unsqueeze(-1),
        weight.unsqueeze(-1),
        bias,
        (stride_width, 1),
        (padding_width, 0),
        (dilation_width, 1),
        groups,
    ).squeeze(-1)
    if trim:
        output = output[..., trim:]
    return output