Coverage for src/flag_gems/ops/conv_depthwise2d.py: 40%
10 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3from flag_gems.ops.conv2d import conv2d
# Module-scoped logger, named after this module so it can be filtered per module.
logger = logging.getLogger(name=__name__)
def _conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation):
    """Run a depthwise 2D convolution through the grouped ``conv2d`` kernel.

    The shapes are validated, then ``conv2d`` is called with ``groups`` set to
    the size of ``input``'s dimension 1, so each input channel gets its own
    group of filters.

    Args:
        input: 4D input tensor. Dimension 1 is used as the channel count
            (presumably ``(N, C, H, W)`` layout -- TODO confirm with callers).
        weight: filter tensor. ``weight.shape[0]`` must be a multiple of the
            input channel count, and ``weight.shape[1]`` (input channels per
            group) must be 1.
        kernel_size: not read by this function. It appears to exist only to
            mirror the ``aten::_conv_depthwise2d`` signature -- TODO confirm.
        bias, stride, padding, dilation: forwarded unchanged to ``conv2d``.

    Returns:
        The result of ``conv2d`` for these arguments.

    Raises:
        AssertionError: if a shape check fails. These checks are ``assert``
            statements, so they are skipped when Python runs with ``-O``.
    """
    logger.debug("GEMS DEPTHWISE")
    assert input.ndim == 4, (
        f"Invalid input tensor: must be 4D, received shape {input.shape}"
    )
    # Read dimension 1 only after the rank check above has passed.
    in_channels = input.shape[1]
    out_channels = weight.shape[0]
    assert out_channels % in_channels == 0, (
        "Output channels must be a multiple of input channels, received "
        f"output channels {out_channels}, input channels {in_channels}"
    )
    assert weight.shape[1] == 1, (
        f"Input channels per group must be 1, received {weight.shape[1]}"
    )
    # Depthwise convolution: one group per input channel.
    groups = in_channels
    return conv2d(input, weight, bias, stride, padding, dilation, groups)