mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Use in_channels for depthwise groups, allows using out_channels=N * in_channels
(does not impact existing models). Fix #354.
This commit is contained in:
parent
9811e229f7
commit
1bcc69e0ad
@ -22,7 +22,8 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
|
|||||||
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
|
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
|
||||||
else:
|
else:
|
||||||
depthwise = kwargs.pop('depthwise', False)
|
depthwise = kwargs.pop('depthwise', False)
|
||||||
groups = out_channels if depthwise else kwargs.pop('groups', 1)
|
# for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
|
||||||
|
groups = in_channels if depthwise else kwargs.pop('groups', 1)
|
||||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||||
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
|
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
@ -34,7 +34,7 @@ class MixedConv2d(nn.ModuleDict):
|
|||||||
self.in_channels = sum(in_splits)
|
self.in_channels = sum(in_splits)
|
||||||
self.out_channels = sum(out_splits)
|
self.out_channels = sum(out_splits)
|
||||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||||
conv_groups = out_ch if depthwise else 1
|
conv_groups = in_ch if depthwise else 1
|
||||||
# use add_module to keep key space clean
|
# use add_module to keep key space clean
|
||||||
self.add_module(
|
self.add_module(
|
||||||
str(idx),
|
str(idx),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user