Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv1d.py: 0%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2import math 

3 

4from .conv2d import conv2d 

5 

# Module-level logger: a child of the shared "flag_gems" logger, named after
# this module (any leading dots are stripped from ``__name__`` first).
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))

7 

8 

def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Compute a 1D convolution by delegating to ``conv2d``.

    A trailing singleton axis is appended to ``input`` and ``weight``,
    ``conv2d`` is called with stride 1, dilation 1 and padding 0 on that extra
    axis, and the extra axis is squeezed away from the result.

    Args:
        input: Tensor whose last axis is the length being convolved.
        weight: Kernel tensor whose last axis is the kernel size.
        bias: Forwarded unchanged to ``conv2d``.
        stride: An int, or a list/tuple whose first element is used.
        padding: An int, a list/tuple whose first element is used, or one of
            the strings ``"same"`` / ``"valid"``.
        dilation: An int, or a list/tuple whose first element is used.
        groups: Forwarded unchanged to ``conv2d``.

    Returns:
        The ``conv2d`` result with its trailing axis squeezed. In ``"same"``
        mode any surplus leading element along the last axis is dropped so the
        output length equals the input length.

    Raises:
        AssertionError: If ``padding == "same"`` and the stride is not 1.
        ValueError: If ``padding`` is a string other than ``"same"``/``"valid"``.
    """
    logger.debug("GEMS CONV1D")

    # Normalize the 1-element sequence forms down to plain scalars.
    stride_width = stride[0] if isinstance(stride, (list, tuple)) else stride
    dilation_width = dilation[0] if isinstance(dilation, (list, tuple)) else dilation

    if isinstance(padding, str):
        if padding == "same":
            # Validate the normalized stride so that stride=(1,) / [1] is
            # accepted. Raised explicitly (rather than via ``assert``) so the
            # check is not stripped under ``python -O``; the exception type
            # is kept for backward compatibility.
            if stride_width != 1:
                raise AssertionError(
                    f"Doesn't support any stride values other than 1 in padding = 'same' mode, "
                    f"received stride value {stride}"
                )
            kernel_size = weight.shape[-1]
            # With stride 1, the total padding needed to keep the length is
            # the dilated kernel extent minus one.
            total_padding = dilation_width * (kernel_size - 1)
            # conv2d receives one symmetric amount, so round the half up ...
            padding_width = math.ceil(total_padding / 2)
            # ... which yields one surplus output element when the total is
            # odd; it is trimmed from the front of the length axis below.
            # NOTE(review): presumably this mirrors PyTorch placing the odd
            # padding element on the trailing side -- confirm against conv2d.
            surplus = 2 * padding_width - total_padding
            return conv2d(
                input.unsqueeze(-1),
                weight.unsqueeze(-1),
                bias,
                (stride_width, 1),
                (padding_width, 0),
                (dilation_width, 1),
                groups,
            ).squeeze(-1)[..., surplus:]
        if padding == "valid":
            # The string is forwarded as-is; conv2d is expected to interpret
            # "valid" itself.
            return conv2d(
                input.unsqueeze(-1),
                weight.unsqueeze(-1),
                bias,
                (stride_width, 1),
                padding,
                (dilation_width, 1),
                groups,
            ).squeeze(-1)
        raise ValueError(
            f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
        )

    # Numeric padding: only the first element of a sequence is meaningful.
    padding_width = padding[0] if isinstance(padding, (list, tuple)) else padding
    return conv2d(
        input.unsqueeze(-1),
        weight.unsqueeze(-1),
        bias,
        (stride_width, 1),
        (padding_width, 0),
        (dilation_width, 1),
        groups,
    ).squeeze(-1)