mirror of https://github.com/open-mmlab/mmcv.git

Add building bricks of cnn (#247)

* add building bricks of cnn
* add unit tests
* use registry for building bricks
* minor updates
* add scale layer
* add test for scale
* add doc string

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>

parent 89f709e8e7
commit 45111e193d
@@ -1,5 +1,10 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                     PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
+                     build_activation_layer, build_conv_layer,
+                     build_norm_layer, build_padding_layer,
+                     build_upsample_layer)
 from .resnet import ResNet, make_res_layer
 from .vgg import VGG, make_vgg_layer
 from .weight_init import (bias_init_with_prob, caffe2_xavier_init,

@@ -9,5 +14,8 @@ from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
     'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob'
+    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
+    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
+    'build_padding_layer', 'build_upsample_layer', 'ACTIVATION_LAYERS',
+    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
 ]

@@ -0,0 +1,16 @@
from .activation import build_activation_layer
from .conv import build_conv_layer
from .conv_module import ConvModule
from .norm import build_norm_layer
from .padding import build_padding_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                       PADDING_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .upsample import build_upsample_layer

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'Scale'
]

@@ -0,0 +1,24 @@
import torch.nn as nn

from mmcv.utils import build_from_cfg
from .registry import ACTIVATION_LAYERS

for module in [
        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
        nn.Sigmoid, nn.Tanh
]:
    ACTIVATION_LAYERS.register_module(module=module)


def build_activation_layer(cfg):
    """Build activation layer.

    Args:
        cfg (dict): The activation layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate an activation layer.

    Returns:
        nn.Module: Created activation layer.
    """
    return build_from_cfg(cfg, ACTIVATION_LAYERS)

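A minimal usage sketch of the builder above (illustrative, not part of the diff): the 'type' key selects the registered class and every remaining key is forwarded to its constructor by build_from_cfg.

from mmcv.cnn.bricks import build_activation_layer

# 'ReLU' and 'LeakyReLU' are registered above; extra keys become constructor kwargs
act = build_activation_layer(dict(type='ReLU', inplace=True))               # nn.ReLU(inplace=True)
lrelu = build_activation_layer(dict(type='LeakyReLU', negative_slope=0.1))  # nn.LeakyReLU(0.1)
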
@@ -0,0 +1,43 @@
from torch import nn as nn

from .registry import CONV_LAYERS

CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)


def build_conv_layer(cfg, *args, **kwargs):
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a conv layer.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in CONV_LAYERS:
        raise KeyError(f'Unrecognized conv type {layer_type}')
    else:
        conv_layer = CONV_LAYERS.get(layer_type)

    layer = conv_layer(*args, **kwargs, **cfg_)

    return layer

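A short usage sketch for build_conv_layer() (illustrative, not part of the diff): positional and keyword args go straight to the conv constructor, and cfg=None falls back to a plain nn.Conv2d.

from mmcv.cnn.bricks import build_conv_layer

conv = build_conv_layer(None, 3, 8, kernel_size=3, padding=1)        # plain nn.Conv2d(3, 8, 3, padding=1)
conv1d = build_conv_layer(dict(type='Conv1d'), 3, 8, kernel_size=3)  # registered 'Conv1d' -> nn.Conv1d
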
@@ -0,0 +1,174 @@
import warnings

import torch.nn as nn

from ..weight_init import constant_init, kaiming_init
from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer


class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
       supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm in conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            self.padding_layer = build_padding_layer(pad_cfg, padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        return getattr(self, self.norm_name)

    def init_weights(self):
        if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
            nonlinearity = 'leaky_relu'
            a = self.act_cfg.get('negative_slope', 0.01)
        else:
            nonlinearity = 'relu'
            a = 0

        kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x

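A usage sketch of ConvModule (illustrative, not part of the diff), showing the default conv+norm+act assembly and the bias='auto' behaviour described in the docstring above.

import torch
from mmcv.cnn.bricks import ConvModule

# conv + BN + ReLU; bias is dropped automatically because a norm layer follows the conv
block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
out = block(torch.rand(1, 3, 32, 32))   # -> torch.Size([1, 16, 32, 32])

# conv + LeakyReLU only, with spectral norm applied to the conv weights
block = ConvModule(3, 16, 3, padding=1, act_cfg=dict(type='LeakyReLU'),
                   with_spectral_norm=True)
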
@@ -0,0 +1,118 @@
import inspect

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

from .registry import NORM_LAYERS

NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
NORM_LAYERS.register_module('SyncBN', module=nn.SyncBatchNorm)
NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)


def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    When we build a norm layer with `build_norm_layer()`, we want to preserve
    the norm type in variable names, e.g., self.bn1, self.gn. This method will
    infer the abbreviation to map class types to abbreviations.

    Rule 1: If the class has the property "abbr", return the property.
    Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
        InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
        "in" respectively.
    Rule 3: If the class name contains "batch", "group", "layer" or "instance",
        the abbreviation of this layer will be "bn", "gn", "ln" and "in"
        respectively.
    Rule 4: Otherwise, the abbreviation falls back to "norm".

    Args:
        class_type (type): The norm layer type.

    Returns:
        str: The inferred abbreviation.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, 'abbr'):
        return class_type.abbr
    if issubclass(class_type, _InstanceNorm):  # IN is a subclass of BN
        return 'in'
    elif issubclass(class_type, _BatchNorm):
        return 'bn'
    elif issubclass(class_type, nn.GroupNorm):
        return 'gn'
    elif issubclass(class_type, nn.LayerNorm):
        return 'ln'
    else:
        class_name = class_type.__name__.lower()
        if 'batch' in class_name:
            return 'bn'
        elif 'group' in class_name:
            return 'gn'
        elif 'layer' in class_name:
            return 'ln'
        elif 'instance' in class_name:
            return 'in'
        else:
            return 'norm'


def build_norm_layer(cfg, num_features, postfix=''):
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Whether to stop gradient updates.
        num_features (int): Number of input channels.
        postfix (int | str): The postfix to be appended into norm abbreviation
            to create named layer.

    Returns:
        tuple[str, nn.Module]:
            name (str): The layer name consisting of abbreviation and postfix,
                e.g., bn1, gn.
            layer (nn.Module): Created norm layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in NORM_LAYERS:
        raise KeyError(f'Unrecognized norm type {layer_type}')

    norm_layer = NORM_LAYERS.get(layer_type)
    abbr = infer_abbr(norm_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN':
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer

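A usage sketch of build_norm_layer() (illustrative, not part of the diff): it returns a (name, layer) pair, where the name combines the inferred abbreviation with the postfix.

from mmcv.cnn.bricks import build_norm_layer

name, bn = build_norm_layer(dict(type='BN'), 64, postfix=1)      # name == 'bn1', nn.BatchNorm2d(64)
name, gn = build_norm_layer(dict(type='GN', num_groups=4), 64)   # name == 'gn', nn.GroupNorm(4, 64)
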
@@ -0,0 +1,35 @@
import torch.nn as nn

from .registry import PADDING_LAYERS

PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)


def build_padding_layer(cfg, *args, **kwargs):
    """Build padding layer.

    Args:
        cfg (None or dict): The padding layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a padding layer.

    Returns:
        nn.Module: Created padding layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')
    if padding_type not in PADDING_LAYERS:
        raise KeyError(f'Unrecognized padding type {padding_type}.')
    else:
        padding_layer = PADDING_LAYERS.get(padding_type)

    layer = padding_layer(*args, **kwargs, **cfg_)

    return layer

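A usage sketch of build_padding_layer() (illustrative, not part of the diff); extra positional args are handed to the padding layer's constructor.

import torch
from mmcv.cnn.bricks import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)   # nn.ReflectionPad2d(2)
y = pad(torch.rand(1, 3, 5, 5))                      # -> torch.Size([1, 3, 9, 9])
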
@@ -0,0 +1,7 @@
from mmcv.utils import Registry

CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')

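Because these registries are plain mmcv.utils.Registry objects, downstream code can register its own bricks and build them through the same config interface. A hypothetical sketch (the MyConv class and its registered name are made up for illustration):

import torch.nn as nn
from mmcv.cnn.bricks import CONV_LAYERS, build_conv_layer

@CONV_LAYERS.register_module(name='MyConv')
class MyConv(nn.Conv2d):               # hypothetical custom conv brick
    pass

layer = build_conv_layer(dict(type='MyConv'), 3, 8, kernel_size=1)
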
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    This layer scales the input by a learnable factor. It multiplies a
    learnable scale parameter of shape (1,) with input of any shape.

    Args:
        scale (float): Initial value of scale factor. Default: 1.0
    """

    def __init__(self, scale=1.0):
        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x):
        return x * self.scale

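A usage sketch of Scale (illustrative, not part of the diff); the single learnable parameter broadcasts over any input shape.

import torch
from mmcv.cnn.bricks import Scale

scale = Scale(0.5)                  # learnable scalar, initialised to 0.5
y = scale(torch.rand(2, 4, 8, 8))   # every element multiplied by the parameter
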
@@ -0,0 +1,79 @@
import torch.nn as nn
import torch.nn.functional as F

from ..weight_init import xavier_init
from .registry import UPSAMPLE_LAYERS

UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('deconv', module=nn.ConvTranspose2d)


@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
    achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer to expand the
            channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x


def build_upsample_layer(cfg):
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:
            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.

    Returns:
        nn.Module: Created upsample layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    else:
        upsample = UPSAMPLE_LAYERS.get(layer_type)

    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    layer = upsample(**cfg_)
    return layer

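A usage sketch of build_upsample_layer() (illustrative, not part of the diff); note that for 'nearest'/'bilinear' the type doubles as the mode of nn.Upsample.

from mmcv.cnn.bricks import build_upsample_layer

up = build_upsample_layer(dict(type='bilinear', scale_factor=2))   # nn.Upsample(scale_factor=2, mode='bilinear')
deconv = build_upsample_layer(
    dict(type='deconv', in_channels=8, out_channels=8, kernel_size=4, stride=2))
ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=8, out_channels=8,
         scale_factor=2, upsample_kernel=3))
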
@@ -100,6 +100,8 @@ class Registry(object):
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):

@@ -109,7 +111,7 @@ class Registry(object):
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
-            return
+            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):

@@ -119,6 +121,7 @@ class Registry(object):
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

@@ -132,10 +135,13 @@ def build_from_cfg(cfg, registry, default_args=None):
        default_args (dict, optional): Default initialization arguments.

    Returns:
-        obj: The constructed object.
+        object: The constructed object.
    """
-    if not (isinstance(cfg, dict) and 'type' in cfg):
-        raise TypeError('cfg must be a dict containing the key "type"')
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'type' not in cfg:
+        raise KeyError(
+            f'the cfg dict must contain the key "type", but got {cfg}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

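A sketch of the registration style the updated register_module() adds (passing the class directly), together with the stricter error behaviour of build_from_cfg(); the MODELS registry and ResNetLike class are made up for this example.

from mmcv.utils import Registry, build_from_cfg

MODELS = Registry('model')               # hypothetical registry for this sketch

class ResNetLike:                        # hypothetical class
    def __init__(self, depth=50):
        self.depth = depth

# new style: pass the class directly instead of decorating it
MODELS.register_module('ResNetLike', module=ResNetLike)

model = build_from_cfg(dict(type='ResNetLike', depth=101), MODELS)
# a cfg without "type" now raises KeyError instead of TypeError,
# while a non-dict cfg still raises TypeError.
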
@@ -0,0 +1,245 @@
import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                             PADDING_LAYERS, build_activation_layer,
                             build_conv_layer, build_norm_layer,
                             build_padding_layer, build_upsample_layer)
from mmcv.cnn.bricks.norm import infer_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack


def test_build_conv_layer():
    with pytest.raises(TypeError):
        # cfg must be a dict
        cfg = 'Conv2d'
        build_conv_layer(cfg)

    with pytest.raises(KeyError):
        # `type` must be in cfg
        cfg = dict(kernel_size=3)
        build_conv_layer(cfg)

    with pytest.raises(KeyError):
        # unsupported conv type
        cfg = dict(type='FancyConv')
        build_conv_layer(cfg)

    kwargs = dict(
        in_channels=4, out_channels=8, kernel_size=3, groups=2, dilation=2)
    cfg = None
    layer = build_conv_layer(cfg, **kwargs)
    assert isinstance(layer, nn.Conv2d)
    assert layer.in_channels == kwargs['in_channels']
    assert layer.out_channels == kwargs['out_channels']
    assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
    assert layer.groups == kwargs['groups']
    assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])

    cfg = dict(type='Conv')
    layer = build_conv_layer(cfg, **kwargs)
    assert isinstance(layer, nn.Conv2d)
    assert layer.in_channels == kwargs['in_channels']
    assert layer.out_channels == kwargs['out_channels']
    assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
    assert layer.groups == kwargs['groups']
    assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])

    for type_name, module in CONV_LAYERS.module_dict.items():
        cfg = dict(type=type_name)
        layer = build_conv_layer(cfg, **kwargs)
        assert isinstance(layer, module)
        assert layer.in_channels == kwargs['in_channels']
        assert layer.out_channels == kwargs['out_channels']


def test_infer_abbr():
    with pytest.raises(TypeError):
        # class_type must be a class
        infer_abbr(0)

    class MyNorm:

        abbr = 'mn'

    assert infer_abbr(MyNorm) == 'mn'

    class FancyBatchNorm:
        pass

    assert infer_abbr(FancyBatchNorm) == 'bn'

    class FancyInstanceNorm:
        pass

    assert infer_abbr(FancyInstanceNorm) == 'in'

    class FancyLayerNorm:
        pass

    assert infer_abbr(FancyLayerNorm) == 'ln'

    class FancyGroupNorm:
        pass

    assert infer_abbr(FancyGroupNorm) == 'gn'

    class FancyNorm:
        pass

    assert infer_abbr(FancyNorm) == 'norm'


def test_build_norm_layer():
    with pytest.raises(TypeError):
        # cfg must be a dict
        cfg = 'BN'
        build_norm_layer(cfg, 3)

    with pytest.raises(KeyError):
        # `type` must be in cfg
        cfg = dict()
        build_norm_layer(cfg, 3)

    with pytest.raises(KeyError):
        # unsupported norm type
        cfg = dict(type='FancyNorm')
        build_norm_layer(cfg, 3)

    with pytest.raises(AssertionError):
        # postfix must be int or str
        cfg = dict(type='BN')
        build_norm_layer(cfg, 3, postfix=[1, 2])

    with pytest.raises(AssertionError):
        # `num_groups` must be in cfg when using 'GN'
        cfg = dict(type='GN')
        build_norm_layer(cfg, 3)

    # test each type of norm layer in norm_cfg
    abbr_mapping = {
        'BN': 'bn',
        'BN1d': 'bn',
        'BN2d': 'bn',
        'BN3d': 'bn',
        'SyncBN': 'bn',
        'GN': 'gn',
        'LN': 'ln',
        'IN': 'in',
        'IN1d': 'in',
        'IN2d': 'in',
        'IN3d': 'in',
    }
    for type_name, module in NORM_LAYERS.module_dict.items():
        for postfix in ['_test', 1]:
            cfg = dict(type=type_name)
            if type_name == 'GN':
                cfg['num_groups'] = 2
            name, layer = build_norm_layer(cfg, 3, postfix=postfix)
            assert name == abbr_mapping[type_name] + str(postfix)
            assert isinstance(layer, module)
            if type_name == 'GN':
                assert layer.num_channels == 3
                assert layer.num_groups == cfg['num_groups']
            elif type_name != 'LN':
                assert layer.num_features == 3


def test_build_activation_layer():
    with pytest.raises(TypeError):
        # cfg must be a dict
        cfg = 'ReLU'
        build_activation_layer(cfg)

    with pytest.raises(KeyError):
        # `type` must be in cfg
        cfg = dict()
        build_activation_layer(cfg)

    with pytest.raises(KeyError):
        # unsupported activation type
        cfg = dict(type='FancyReLU')
        build_activation_layer(cfg)

    # test each type of activation layer in activation_cfg
    for type_name, module in ACTIVATION_LAYERS.module_dict.items():
        cfg['type'] = type_name
        layer = build_activation_layer(cfg)
        assert isinstance(layer, module)


def test_build_padding_layer():
    with pytest.raises(TypeError):
        # cfg must be a dict
        cfg = 'reflect'
        build_padding_layer(cfg)

    with pytest.raises(KeyError):
        # `type` must be in cfg
        cfg = dict()
        build_padding_layer(cfg)

    with pytest.raises(KeyError):
        # unsupported padding type
        cfg = dict(type='FancyPad')
        build_padding_layer(cfg)

    for type_name, module in PADDING_LAYERS.module_dict.items():
        cfg['type'] = type_name
        layer = build_padding_layer(cfg, 2)
        assert isinstance(layer, module)

    input_x = torch.randn(1, 2, 5, 5)
    cfg = dict(type='reflect')
    padding_layer = build_padding_layer(cfg, 2)
    res = padding_layer(input_x)
    assert res.shape == (1, 2, 9, 9)


def test_upsample_layer():
    with pytest.raises(TypeError):
        # cfg must be a dict
        cfg = 'bilinear'
        build_upsample_layer(cfg)

    with pytest.raises(KeyError):
        # `type` must be in cfg
        cfg = dict()
        build_upsample_layer(cfg)

    with pytest.raises(KeyError):
        # unsupported upsample type
        cfg = dict(type='FancyUpsample')
        build_upsample_layer(cfg)

    for type_name in ['nearest', 'bilinear']:
        cfg['type'] = type_name
        layer = build_upsample_layer(cfg)
        assert isinstance(layer, nn.Upsample)
        assert layer.mode == type_name

    cfg = dict(
        type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2)
    layer = build_upsample_layer(cfg)
    assert isinstance(layer, nn.ConvTranspose2d)

    cfg = dict(
        type='pixel_shuffle',
        in_channels=3,
        out_channels=3,
        scale_factor=2,
        upsample_kernel=3)
    layer = build_upsample_layer(cfg)

    assert isinstance(layer, PixelShufflePack)
    assert layer.scale_factor == 2
    assert layer.upsample_kernel == 3


def test_pixel_shuffle_pack():
    x_in = torch.rand(2, 3, 10, 10)
    pixel_shuffle = PixelShufflePack(3, 3, scale_factor=2, upsample_kernel=3)
    assert pixel_shuffle.upsample_conv.kernel_size == (3, 3)
    x_out = pixel_shuffle(x_in)
    assert x_out.shape == (2, 3, 20, 20)

@@ -0,0 +1,156 @@
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import ConvModule


def test_conv_module():
    with pytest.raises(AssertionError):
        # conv_cfg must be a dict or None
        conv_cfg = 'conv'
        ConvModule(3, 8, 2, conv_cfg=conv_cfg)

    with pytest.raises(AssertionError):
        # norm_cfg must be a dict or None
        norm_cfg = 'norm'
        ConvModule(3, 8, 2, norm_cfg=norm_cfg)

    with pytest.raises(KeyError):
        # softmax is not supported
        act_cfg = dict(type='softmax')
        ConvModule(3, 8, 2, act_cfg=act_cfg)

    # conv + norm + act
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert conv.with_norm
    assert hasattr(conv, 'norm')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv + act
    conv = ConvModule(3, 8, 2)
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert not conv.with_norm
    assert not hasattr(conv, 'norm')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv
    conv = ConvModule(3, 8, 2, act_cfg=None)
    assert not conv.with_norm
    assert not hasattr(conv, 'norm')
    assert not conv.with_activation
    assert not hasattr(conv, 'activate')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # with_spectral_norm=True
    conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # padding_mode='reflect'
    conv = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # non-existing padding mode
    with pytest.raises(KeyError):
        conv = ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # leaky relu
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert isinstance(conv.activate, nn.LeakyReLU)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)


def test_bias():
    # bias: auto, without norm
    conv = ConvModule(3, 8, 2)
    assert conv.conv.bias is not None

    # bias: auto, with norm
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.conv.bias is None

    # bias: False, without norm
    conv = ConvModule(3, 8, 2, bias=False)
    assert conv.conv.bias is None

    # bias: True, with norm
    with pytest.warns(UserWarning) as record:
        ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN'))
    assert len(record) == 1
    assert record[0].message.args[
        0] == 'ConvModule has norm and bias at the same time'


def conv_forward(self, x):
    return x + '_conv'


def bn_forward(self, x):
    return x + '_bn'


def relu_forward(self, x):
    return x + '_relu'


@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():

    with pytest.raises(AssertionError):
        # order must be a tuple
        order = ['conv', 'norm', 'act']
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # length of order must be 3
        order = ('conv', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'something')
        ConvModule(3, 8, 2, order=order)

    # ('conv', 'norm', 'act')
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input')
    assert out == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    out = conv('input')
    assert out == 'input_bn_conv_relu'

    # ('conv', 'norm', 'act'), activate=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', activate=False)
    assert out == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', norm=False)
    assert out == 'input_conv_relu'

@@ -0,0 +1,21 @@
import torch

from mmcv.cnn.bricks import Scale


def test_scale():
    # test default scale
    scale = Scale()
    assert scale.scale.data == 1.
    assert scale.scale.dtype == torch.float
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)

    # test given scale
    scale = Scale(10.)
    assert scale.scale.data == 10.
    assert scale.scale.dtype == torch.float
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)

@@ -179,16 +179,16 @@ def test_build_from_cfg():
        model = mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
-    with pytest.raises(TypeError):
+    with pytest.raises(KeyError):
        cfg = dict(depth=50, stages=4)
        model = mmcv.build_from_cfg(cfg, BACKBONES)

    # incorrect registry type
    with pytest.raises(TypeError):
-        dict(type='ResNet', depth=50)
+        cfg = dict(type='ResNet', depth=50)
        model = mmcv.build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type
    with pytest.raises(TypeError):
-        dict(type='ResNet', depth=50)
+        cfg = dict(type='ResNet', depth=50)
        model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)