Ross Wightman f902bcd54c Layer refactoring continues, ResNet downsample rewrite for proper dilation in 3x3 and avg_pool cases
* select_conv2d -> create_conv2d
* added create_attn to create attention module from string/bool/module
* factor padding helpers into own file, use in both conv2d_same and avg_pool2d_same
* add some more test eca resnet variants
* minor tweaks, naming, comments, consistency
2020-02-10 11:55:03 -08:00

31 lines
1.3 KiB
Python

""" Create Conv2d Factory Method
Hacked together by Ross Wightman
"""
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """ Select a 2d convolution implementation based on arguments

    Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
    Used extensively by EfficientNet, MobileNetv3 and related networks.

    Args:
        in_chs: number of input channels.
        out_chs: number of output channels.
        kernel_size: a ``list`` selects MixedConv2d (one entry per kernel group);
            ints, tuples and other iterables go to the regular / CondConv path.
        **kwargs: forwarded to the selected implementation. Recognised here:
            ``depthwise`` (bool, default False) requests a depthwise conv on the
            non-list path, and ``num_experts`` (> 0) selects CondConv2d.
            ``groups`` must not be passed; it is derived from ``depthwise``.

    Returns:
        The module built by MixedConv2d, CondConv2d or create_conv2d_pad.

    Raises:
        AssertionError: if ``groups`` is passed, or if ``num_experts`` is combined
            with a list ``kernel_size``.
    """
    assert 'groups' not in kwargs, "only use 'depthwise' bool arg, not 'groups'"

    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs, 'MixNet + CondConv combo not supported currently'
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        # 'depthwise' (if present) is left in kwargs and forwarded as-is.
        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)

    depthwise = kwargs.pop('depthwise', False)
    if not depthwise:
        groups = 1
    elif in_chs and out_chs % in_chs == 0:
        # Depthwise means one group per *input* channel; out_chs may be any multiple
        # of in_chs (channel multiplier). Using out_chs here made multiplier > 1 fail
        # because in_chs is then not divisible by groups.
        groups = in_chs
    else:
        # out_chs is not a multiple of in_chs: keep the previous grouping so that
        # configurations which constructed before still yield the same module.
        groups = out_chs

    if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
        return CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
    # A non-positive 'num_experts' stays in kwargs and is forwarded unchanged.
    return create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)