Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/conv_depthwise2d.py: 0%

10 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2 

3from .conv2d import conv2d 

4 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots in __name__ are stripped).
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))

6 

7 

def _conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation):
    """Depthwise 2-D convolution implemented on top of the generic ``conv2d``.

    Validates the argument shapes, then forwards to ``conv2d`` with ``groups``
    set to the number of input channels (``input.shape[1]``).

    Args:
        input: 4-D input tensor. Dim 1 is treated as the channel dimension
            (presumably NCHW layout -- confirm with callers).
        weight: Filter tensor. ``weight.shape[0]`` (output channels) must be a
            multiple of the input channel count, and ``weight.shape[1]``
            (input channels per group) must be 1.
        kernel_size: Accepted for signature compatibility but not used by this
            function; presumably ``conv2d`` takes the kernel extent from
            ``weight`` -- confirm.
        bias, stride, padding, dilation: Forwarded unchanged to ``conv2d``.

    Returns:
        The result of ``conv2d`` for these arguments.

    Raises:
        AssertionError: If any of the shape checks fails.
    """
    logger.debug("GEMS DEPTHWISE")
    # NOTE(review): these checks use ``assert`` and are stripped under
    # ``python -O``. They are kept as asserts so that callers still see
    # AssertionError as the exception type.
    assert input.ndim == 4, (
        f"Invalid input tensor: must be 4D, received shape {input.shape}"
    )
    # Dim 1 is the channel count; it is also the ``groups`` value passed on.
    in_channels = input.shape[1]
    assert weight.shape[0] % in_channels == 0, (
        "Output channels must be multiple of input, "
        f"received output {weight.shape[0]}, input {in_channels}"
    )
    assert weight.shape[1] == 1, (
        f"input channels per group must be 1, received {weight.shape[1]}"
    )
    return conv2d(input, weight, bias, stride, padding, dilation, in_channels)