Use in_channels for depthwise groups, allows using `out_channels=N * in_channels` (does not impact existing models). Fix #354.
parent
9811e229f7
commit
1bcc69e0ad
|
@ -22,7 +22,8 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
|
|||
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
|
||||
else:
|
||||
depthwise = kwargs.pop('depthwise', False)
|
||||
groups = out_channels if depthwise else kwargs.pop('groups', 1)
|
||||
# for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
|
||||
groups = in_channels if depthwise else kwargs.pop('groups', 1)
|
||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
|
||||
else:
|
||||
|
|
|
@ -34,7 +34,7 @@ class MixedConv2d(nn.ModuleDict):
|
|||
self.in_channels = sum(in_splits)
|
||||
self.out_channels = sum(out_splits)
|
||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||
conv_groups = out_ch if depthwise else 1
|
||||
conv_groups = in_ch if depthwise else 1
|
||||
# use add_module to keep key space clean
|
||||
self.add_module(
|
||||
str(idx),
|
||||
|
|
Loading…
Reference in New Issue