Add building bricks of cnn ()

* add building bricks of cnn

* add unit tests

* use registry for building bricks

* minor updates

* add scale layer

* add test for scale

* add doc string

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
pull/261/head
Kai Chen 2020-05-01 00:32:25 +08:00 committed by GitHub
parent 89f709e8e7
commit 45111e193d
15 changed files with 960 additions and 8 deletions

mmcv/cnn/__init__.py

@@ -1,5 +1,10 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
build_upsample_layer)
from .resnet import ResNet, make_res_layer
from .vgg import VGG, make_vgg_layer
from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
@@ -9,5 +14,8 @@ from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob'
+    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
+    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
+    'build_padding_layer', 'build_upsample_layer', 'ACTIVATION_LAYERS',
+    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
]

mmcv/cnn/bricks/__init__.py

@@ -0,0 +1,16 @@
from .activation import build_activation_layer
from .conv import build_conv_layer
from .conv_module import ConvModule
from .norm import build_norm_layer
from .padding import build_padding_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .upsample import build_upsample_layer
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale'
]

mmcv/cnn/bricks/activation.py

@@ -0,0 +1,24 @@
import torch.nn as nn
from mmcv.utils import build_from_cfg
from .registry import ACTIVATION_LAYERS
for module in [
nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
nn.Sigmoid, nn.Tanh
]:
ACTIVATION_LAYERS.register_module(module=module)
def build_activation_layer(cfg):
"""Build activation layer.
Args:
cfg (dict): The activation layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate an activation layer.
Returns:
nn.Module: Created activation layer.
"""
return build_from_cfg(cfg, ACTIVATION_LAYERS)
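A minimal usage sketch (the import path follows the updated mmcv/cnn/__init__.py above); every cfg key besides `type` is forwarded to the layer constructor:

from mmcv.cnn import build_activation_layer

# extra cfg keys become constructor kwargs of the selected layer
act = build_activation_layer(
    dict(type='LeakyReLU', negative_slope=0.2, inplace=True))
# act is an nn.LeakyReLU(negative_slope=0.2, inplace=True)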

mmcv/cnn/bricks/conv.py

@@ -0,0 +1,43 @@
from torch import nn as nn
from .registry import CONV_LAYERS
CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
def build_conv_layer(cfg, *args, **kwargs):
"""Build convolution layer.
Args:
cfg (None or dict): The conv layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a conv layer.
args (argument list): Arguments passed to the `__init__`
method of the corresponding conv layer.
kwargs (keyword arguments): Keyword arguments passed to the `__init__`
method of the corresponding conv layer.
Returns:
nn.Module: Created conv layer.
"""
if cfg is None:
cfg_ = dict(type='Conv2d')
else:
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
if 'type' not in cfg:
raise KeyError('the cfg dict must contain the key "type"')
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if layer_type not in CONV_LAYERS:
raise KeyError(f'Unrecognized conv type {layer_type}')
else:
conv_layer = CONV_LAYERS.get(layer_type)
layer = conv_layer(*args, **kwargs, **cfg_)
return layer
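A short sketch of the two entry points; positional and keyword args go straight to the layer's constructor, with leftover cfg keys merged in:

from mmcv.cnn import build_conv_layer

# cfg=None falls back to a plain nn.Conv2d
conv = build_conv_layer(None, 3, 8, kernel_size=3, padding=1)

# an explicit type picks any registered conv layer
conv1d = build_conv_layer(dict(type='Conv1d'), 3, 8, kernel_size=3)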

mmcv/cnn/bricks/conv_module.py

@@ -0,0 +1,174 @@
import warnings
import torch.nn as nn
from ..weight_init import constant_init, kaiming_init
from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer
class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers.
This block simplifies the usage of convolution layers, which are commonly
used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
It is based upon three build methods: `build_conv_layer()`,
`build_norm_layer()` and `build_activation_layer()`.
Besides, we add some additional features in this module:
1. Automatically set the `bias` of the conv layer.
2. Spectral norm is supported.
3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
supports zero and circular padding, and we add a "reflect" padding mode.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int | tuple[int]): Same as nn.Conv2d.
stride (int | tuple[int]): Same as nn.Conv2d.
padding (int | tuple[int]): Same as nn.Conv2d.
dilation (int | tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool | str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
False. Default: "auto".
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
inplace (bool): Whether to use inplace mode for activation.
Default: True.
with_spectral_norm (bool): Whether use spectral norm in conv module.
Default: False.
padding_mode (str): If the `padding_mode` is not supported by the
current `Conv2d` in PyTorch, we will use our own padding layer
instead. Currently, we support ['zeros', 'circular'] with the official
implementation and ['reflect'] with our own implementation.
Default: 'zeros'.
order (tuple[str]): The order of conv/norm/activation layers. It is a
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias='auto',
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
inplace=True,
with_spectral_norm=False,
padding_mode='zeros',
order=('conv', 'norm', 'act')):
super(ConvModule, self).__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
assert act_cfg is None or isinstance(act_cfg, dict)
official_padding_mode = ['zeros', 'circular']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.inplace = inplace
self.with_spectral_norm = with_spectral_norm
self.with_explicit_padding = padding_mode not in official_padding_mode
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == set(['conv', 'norm', 'act'])
self.with_norm = norm_cfg is not None
self.with_activation = act_cfg is not None
# if the conv layer is before a norm layer, bias is unnecessary.
if bias == 'auto':
bias = False if self.with_norm else True
self.with_bias = bias
if self.with_norm and self.with_bias:
warnings.warn('ConvModule has norm and bias at the same time')
if self.with_explicit_padding:
pad_cfg = dict(type=padding_mode)
self.padding_layer = build_padding_layer(pad_cfg, padding)
# reset padding to 0 for conv module
conv_padding = 0 if self.with_explicit_padding else padding
# build convolution layer
self.conv = build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=conv_padding,
dilation=dilation,
groups=groups,
bias=bias)
# export the attributes of self.conv to a higher level for convenience
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups
if self.with_spectral_norm:
self.conv = nn.utils.spectral_norm(self.conv)
# build normalization layers
if self.with_norm:
# norm layer is after conv layer
if order.index('norm') > order.index('conv'):
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.add_module(self.norm_name, norm)
# build activation layer
if self.with_activation:
act_cfg_ = act_cfg.copy()
act_cfg_.setdefault('inplace', inplace)
self.activate = build_activation_layer(act_cfg_)
# Use msra init by default
self.init_weights()
@property
def norm(self):
return getattr(self, self.norm_name)
def init_weights(self):
if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
nonlinearity = 'leaky_relu'
a = self.act_cfg.get('negative_slope', 0.01)
else:
nonlinearity = 'relu'
a = 0
kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
if self.with_norm:
constant_init(self.norm, 1, bias=0)
def forward(self, x, activate=True, norm=True):
for layer in self.order:
if layer == 'conv':
if self.with_explicit_padding:
x = self.padding_layer(x)
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x)
elif layer == 'act' and activate and self.with_activation:
x = self.activate(x)
return x
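A usage sketch: with a norm layer configured, `bias` resolves to False automatically, and a non-official padding mode inserts an explicit padding layer in front of the conv:

import torch
from mmcv.cnn import ConvModule

# conv -> BN -> ReLU; bias is disabled because a norm layer follows
block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
out = block(torch.rand(1, 3, 32, 32))  # shape (1, 16, 32, 32)

# pre-norm ordering plus explicit reflection padding
pre_norm = ConvModule(16, 16, 3, padding=1, norm_cfg=dict(type='BN'),
                      order=('norm', 'act', 'conv'), padding_mode='reflect')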

mmcv/cnn/bricks/norm.py

@@ -0,0 +1,118 @@
import inspect
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from .registry import NORM_LAYERS
NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
NORM_LAYERS.register_module('SyncBN', module=nn.SyncBatchNorm)
NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
def infer_abbr(class_type):
"""Infer abbreviation from the class name.
When we build a norm layer with `build_norm_layer()`, we want to preserve
the norm type in variable names, e.g., self.bn1, self.gn. This method
infers the abbreviation from the class type, using the rules below.
Rule 1: If the class has the property "abbr", return the property.
Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
"in" respectively.
Rule 3: If the class name contains "batch", "group", "layer" or "instance",
the abbreviation of this layer will be "bn", "gn", "ln" and "in"
respectively.
Rule 4: Otherwise, the abbreviation falls back to "norm".
Args:
class_type (type): The norm layer type.
Returns:
str: The inferred abbreviation.
"""
if not inspect.isclass(class_type):
raise TypeError(
f'class_type must be a type, but got {type(class_type)}')
if hasattr(class_type, 'abbr'):
return class_type.abbr
if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
return 'in'
elif issubclass(class_type, _BatchNorm):
return 'bn'
elif issubclass(class_type, nn.GroupNorm):
return 'gn'
elif issubclass(class_type, nn.LayerNorm):
return 'ln'
else:
class_name = class_type.__name__.lower()
if 'batch' in class_name:
return 'bn'
elif 'group' in class_name:
return 'gn'
elif 'layer' in class_name:
return 'ln'
elif 'instance' in class_name:
return 'in'
else:
return 'norm'
def build_norm_layer(cfg, num_features, postfix=''):
"""Build normalization layer.
Args:
cfg (dict): The norm layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a norm layer.
- requires_grad (bool, optional): Whether the layer parameters
require gradient updates. Default: True.
num_features (int): Number of input channels.
postfix (int | str): The postfix to be appended to the norm
abbreviation to create the layer name.
Returns:
tuple[str, nn.Module]:
name (str): The layer name consisting of abbreviation and postfix,
e.g., bn1, gn.
layer (nn.Module): Created norm layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
if 'type' not in cfg:
raise KeyError('the cfg dict must contain the key "type"')
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if layer_type not in NORM_LAYERS:
raise KeyError(f'Unrecognized norm type {layer_type}')
norm_layer = NORM_LAYERS.get(layer_type)
abbr = infer_abbr(norm_layer)
assert isinstance(postfix, (int, str))
name = abbr + str(postfix)
requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN':
layer._specify_ddp_gpu_num(1)
else:
assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_)
for param in layer.parameters():
param.requires_grad = requires_grad
return name, layer
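A sketch of the returned (name, layer) pair; the name combines the inferred abbreviation with the postfix, which is convenient for `add_module()`:

from mmcv.cnn import build_norm_layer

name, bn = build_norm_layer(dict(type='BN', requires_grad=False), 64, postfix=1)
# name == 'bn1'; bn is an nn.BatchNorm2d(64, eps=1e-05) with frozen parameters

# 'GN' receives num_features as num_channels and requires num_groups in the cfg
name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)  # name == 'gn'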

mmcv/cnn/bricks/padding.py

@@ -0,0 +1,35 @@
import torch.nn as nn
from .registry import PADDING_LAYERS
PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
def build_padding_layer(cfg, *args, **kwargs):
"""Build padding layer.
Args:
cfg (None or dict): The padding layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a padding layer.
Returns:
nn.Module: Created padding layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
if 'type' not in cfg:
raise KeyError('the cfg dict must contain the key "type"')
cfg_ = cfg.copy()
padding_type = cfg_.pop('type')
if padding_type not in PADDING_LAYERS:
raise KeyError(f'Unrecognized padding type {padding_type}.')
else:
padding_layer = PADDING_LAYERS.get(padding_type)
layer = padding_layer(*args, **kwargs, **cfg_)
return layer
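A small sketch; positional args after the cfg are passed to the padding layer's constructor:

import torch
from mmcv.cnn import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # nn.ReflectionPad2d(2)
assert pad(torch.randn(1, 2, 5, 5)).shape == (1, 2, 9, 9)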

mmcv/cnn/bricks/registry.py

@@ -0,0 +1,7 @@
from mmcv.utils import Registry
CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
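These registries are what make the bricks extensible: downstream code can register its own classes and then refer to them by `type` in a config. A hypothetical sketch (the `DWConv` name and class are illustrative, not part of this commit):

import torch.nn as nn
from mmcv.cnn.bricks.registry import CONV_LAYERS

@CONV_LAYERS.register_module(name='DWConv')  # hypothetical custom conv
class DepthwiseConv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size,
                         groups=in_channels, **kwargs)

# now usable wherever a conv cfg is accepted, e.g.
# build_conv_layer(dict(type='DWConv'), 8, 8, 3)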

mmcv/cnn/bricks/scale.py

@@ -0,0 +1,20 @@
import torch
import torch.nn as nn
class Scale(nn.Module):
"""A learnable scale parameter.
This layer scales the input by a learnable factor: a single learnable
scalar parameter is multiplied with inputs of any shape via broadcasting.
Args:
scale (float): Initial value of scale factor. Default: 1.0
"""
def __init__(self, scale=1.0):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):
return x * self.scale
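A usage sketch; since the factor is an `nn.Parameter`, it is trained together with the rest of the model:

import torch
from mmcv.cnn import Scale

scale = Scale(10.)
y = scale(torch.rand(2, 4, 8, 8))  # same shape, multiplied by the factor
# scale.scale is an nn.Parameter, so the factor is learned during training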

mmcv/cnn/bricks/upsample.py

@@ -0,0 +1,79 @@
import torch.nn as nn
import torch.nn.functional as F
from ..weight_init import xavier_init
from .registry import UPSAMPLE_LAYERS
UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('deconv', module=nn.ConvTranspose2d)
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
"""Pixel Shuffle upsample layer.
This module packs `F.pixel_shuffle()` and an nn.Conv2d module together to
achieve a simple upsampling with pixel shuffle.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
scale_factor (int): Upsample ratio.
upsample_kernel (int): Kernel size of the conv layer to expand the
channels.
"""
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
super(PixelShufflePack, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.scale_factor = scale_factor
self.upsample_kernel = upsample_kernel
self.upsample_conv = nn.Conv2d(
self.in_channels,
self.out_channels * scale_factor * scale_factor,
self.upsample_kernel,
padding=(self.upsample_kernel - 1) // 2)
self.init_weights()
def init_weights(self):
xavier_init(self.upsample_conv, distribution='uniform')
def forward(self, x):
x = self.upsample_conv(x)
x = F.pixel_shuffle(x, self.scale_factor)
return x
def build_upsample_layer(cfg):
"""Build upsample layer.
Args:
cfg (dict): The upsample layer config, which should contain:
- type (str): Layer type.
- scale_factor (int): Upsample ratio, which is not applicable to
deconv.
- layer args: Args needed to instantiate an upsample layer.
Returns:
nn.Module: Created upsample layer.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
raise KeyError(
f'the cfg dict must contain the key "type", but got {cfg}')
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if layer_type not in UPSAMPLE_LAYERS:
raise KeyError(f'Unrecognized upsample type {layer_type}')
else:
upsample = UPSAMPLE_LAYERS.get(layer_type)
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(**cfg_)
return layer
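A sketch covering the three kinds of registered entries; for 'nearest' and 'bilinear' the type doubles as nn.Upsample's mode:

from mmcv.cnn import build_upsample_layer

up = build_upsample_layer(
    dict(type='bilinear', scale_factor=2, align_corners=False))  # nn.Upsample

deconv = build_upsample_layer(
    dict(type='deconv', in_channels=8, out_channels=8, kernel_size=4, stride=2))

ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=8, out_channels=8,
         scale_factor=2, upsample_kernel=3))  # PixelShufflePack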

mmcv/utils/registry.py

@@ -100,6 +100,8 @@ class Registry(object):
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a workaround to be compatible with the old API,
# though it may introduce unexpected bugs.
if isinstance(name, type):
@@ -109,7 +111,7 @@
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
-            return
+            return module
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
@@ -119,6 +121,7 @@
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
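The net effect of returning `module`/`cls` is that `register_module` now works both as a decorator (with or without arguments) and as a plain call, without losing the reference. A sketch with a throwaway registry (`MODELS` and the classes are illustrative):

from mmcv.utils import Registry

MODELS = Registry('model')  # illustrative registry

@MODELS.register_module()  # decorator without arguments
class SimpleNet:
    pass

@MODELS.register_module(name='alias')  # decorator with an explicit name
class OtherNet:
    pass

cls = MODELS.register_module(module=OtherNet)  # plain call, returns the class
assert cls is OtherNet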
@@ -132,10 +135,13 @@ def build_from_cfg(cfg, registry, default_args=None):
default_args (dict, optional): Default initialization arguments.
Returns:
-        obj: The constructed object.
+        object: The constructed object.
"""
-    if not (isinstance(cfg, dict) and 'type' in cfg):
-        raise TypeError('cfg must be a dict containing the key "type"')
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'type' not in cfg:
+        raise KeyError(
+            f'the cfg dict must contain the key "type", but got {cfg}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
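The stricter validation splits the old blanket TypeError into a TypeError (cfg is not a dict) and a KeyError (the "type" key is missing), which is what the updated registry test at the end of this diff checks. A sketch, reusing the illustrative MODELS registry from above:

import pytest
from mmcv.utils import build_from_cfg

net = build_from_cfg(dict(type='SimpleNet'), MODELS)  # ok

with pytest.raises(TypeError):
    build_from_cfg('SimpleNet', MODELS)  # not a dict
with pytest.raises(KeyError):
    build_from_cfg(dict(depth=50), MODELS)  # missing "type"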

(new test file: building conv/norm/activation/padding/upsample layers)

@@ -0,0 +1,245 @@
import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, build_activation_layer,
build_conv_layer, build_norm_layer,
build_padding_layer, build_upsample_layer)
from mmcv.cnn.bricks.norm import infer_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack
def test_build_conv_layer():
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'Conv2d'
build_conv_layer(cfg)
with pytest.raises(KeyError):
# `type` must be in cfg
cfg = dict(kernel_size=3)
build_conv_layer(cfg)
with pytest.raises(KeyError):
# unsupported conv type
cfg = dict(type='FancyConv')
build_conv_layer(cfg)
kwargs = dict(
in_channels=4, out_channels=8, kernel_size=3, groups=2, dilation=2)
cfg = None
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, nn.Conv2d)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
assert layer.groups == kwargs['groups']
assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])
cfg = dict(type='Conv')
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, nn.Conv2d)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
assert layer.groups == kwargs['groups']
assert layer.dilation == (kwargs['dilation'], kwargs['dilation'])
for type_name, module in CONV_LAYERS.module_dict.items():
cfg = dict(type=type_name)
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, module)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
def test_infer_abbr():
with pytest.raises(TypeError):
# class_type must be a class
infer_abbr(0)
class MyNorm:
abbr = 'mn'
assert infer_abbr(MyNorm) == 'mn'
class FancyBatchNorm:
pass
assert infer_abbr(FancyBatchNorm) == 'bn'
class FancyInstanceNorm:
pass
assert infer_abbr(FancyInstanceNorm) == 'in'
class FancyLayerNorm:
pass
assert infer_abbr(FancyLayerNorm) == 'ln'
class FancyGroupNorm:
pass
assert infer_abbr(FancyGroupNorm) == 'gn'
class FancyNorm:
pass
assert infer_abbr(FancyNorm) == 'norm'
def test_build_norm_layer():
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'BN'
build_norm_layer(cfg, 3)
with pytest.raises(KeyError):
# `type` must be in cfg
cfg = dict()
build_norm_layer(cfg, 3)
with pytest.raises(KeyError):
# unsupported norm type
cfg = dict(type='FancyNorm')
build_norm_layer(cfg, 3)
with pytest.raises(AssertionError):
# postfix must be int or str
cfg = dict(type='BN')
build_norm_layer(cfg, 3, postfix=[1, 2])
with pytest.raises(AssertionError):
# `num_groups` must be in cfg when using 'GN'
cfg = dict(type='GN')
build_norm_layer(cfg, 3)
# test each type of norm layer registered in NORM_LAYERS
abbr_mapping = {
'BN': 'bn',
'BN1d': 'bn',
'BN2d': 'bn',
'BN3d': 'bn',
'SyncBN': 'bn',
'GN': 'gn',
'LN': 'ln',
'IN': 'in',
'IN1d': 'in',
'IN2d': 'in',
'IN3d': 'in',
}
for type_name, module in NORM_LAYERS.module_dict.items():
for postfix in ['_test', 1]:
cfg = dict(type=type_name)
if type_name == 'GN':
cfg['num_groups'] = 2
name, layer = build_norm_layer(cfg, 3, postfix=postfix)
assert name == abbr_mapping[type_name] + str(postfix)
assert isinstance(layer, module)
if type_name == 'GN':
assert layer.num_channels == 3
assert layer.num_groups == cfg['num_groups']
elif type_name != 'LN':
assert layer.num_features == 3
def test_build_activation_layer():
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'ReLU'
build_activation_layer(cfg)
with pytest.raises(KeyError):
# `type` must be in cfg
cfg = dict()
build_activation_layer(cfg)
with pytest.raises(KeyError):
# unsupported activation type
cfg = dict(type='FancyReLU')
build_activation_layer(cfg)
# test each type of activation layer registered in ACTIVATION_LAYERS
for type_name, module in ACTIVATION_LAYERS.module_dict.items():
cfg['type'] = type_name
layer = build_activation_layer(cfg)
assert isinstance(layer, module)
def test_build_padding_layer():
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'reflect'
build_padding_layer(cfg)
with pytest.raises(KeyError):
# `type` must be in cfg
cfg = dict()
build_padding_layer(cfg)
with pytest.raises(KeyError):
# unsupported padding type
cfg = dict(type='FancyPad')
build_padding_layer(cfg)
for type_name, module in PADDING_LAYERS.module_dict.items():
cfg['type'] = type_name
layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module)
input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type='reflect')
padding_layer = build_padding_layer(cfg, 2)
res = padding_layer(input_x)
assert res.shape == (1, 2, 9, 9)
def test_upsample_layer():
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'bilinear'
build_upsample_layer(cfg)
with pytest.raises(KeyError):
# `type` must be in cfg
cfg = dict()
build_upsample_layer(cfg)
with pytest.raises(KeyError):
# unsupported upsample type
cfg = dict(type='FancyUpsample')
build_upsample_layer(cfg)
for type_name in ['nearest', 'bilinear']:
cfg['type'] = type_name
layer = build_upsample_layer(cfg)
assert isinstance(layer, nn.Upsample)
assert layer.mode == type_name
cfg = dict(
type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2)
layer = build_upsample_layer(cfg)
assert isinstance(layer, nn.ConvTranspose2d)
cfg = dict(
type='pixel_shuffle',
in_channels=3,
out_channels=3,
scale_factor=2,
upsample_kernel=3)
layer = build_upsample_layer(cfg)
assert isinstance(layer, PixelShufflePack)
assert layer.scale_factor == 2
assert layer.upsample_kernel == 3
def test_pixel_shuffle_pack():
x_in = torch.rand(2, 3, 10, 10)
pixel_shuffle = PixelShufflePack(3, 3, scale_factor=2, upsample_kernel=3)
assert pixel_shuffle.upsample_conv.kernel_size == (3, 3)
x_out = pixel_shuffle(x_in)
assert x_out.shape == (2, 3, 20, 20)

(new test file: ConvModule)

@@ -0,0 +1,156 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
def test_conv_module():
with pytest.raises(AssertionError):
# conv_cfg must be a dict or None
conv_cfg = 'conv'
ConvModule(3, 8, 2, conv_cfg=conv_cfg)
with pytest.raises(AssertionError):
# norm_cfg must be a dict or None
norm_cfg = 'norm'
ConvModule(3, 8, 2, norm_cfg=norm_cfg)
with pytest.raises(KeyError):
# softmax is not supported
act_cfg = dict(type='softmax')
ConvModule(3, 8, 2, act_cfg=act_cfg)
# conv + norm + act
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
assert conv.with_activation
assert hasattr(conv, 'activate')
assert conv.with_norm
assert hasattr(conv, 'norm')
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# conv + act
conv = ConvModule(3, 8, 2)
assert conv.with_activation
assert hasattr(conv, 'activate')
assert not conv.with_norm
assert not hasattr(conv, 'norm')
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# conv
conv = ConvModule(3, 8, 2, act_cfg=None)
assert not conv.with_norm
assert not hasattr(conv, 'norm')
assert not conv.with_activation
assert not hasattr(conv, 'activate')
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
# with_spectral_norm=True
conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
assert hasattr(conv.conv, 'weight_orig')
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# padding_mode='reflect'
conv = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
assert isinstance(conv.padding_layer, nn.ReflectionPad2d)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
# non-existing padding mode
with pytest.raises(KeyError):
conv = ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')
# leaky relu
conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
assert isinstance(conv.activate, nn.LeakyReLU)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
def test_bias():
# bias: auto, without norm
conv = ConvModule(3, 8, 2)
assert conv.conv.bias is not None
# bias: auto, with norm
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
assert conv.conv.bias is None
# bias: False, without norm
conv = ConvModule(3, 8, 2, bias=False)
assert conv.conv.bias is None
# bias: True, with norm
with pytest.warns(UserWarning) as record:
ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN'))
assert len(record) == 1
assert record[0].message.args[
0] == 'ConvModule has norm and bias at the same time'
def conv_forward(self, x):
return x + '_conv'
def bn_forward(self, x):
return x + '_bn'
def relu_forward(self, x):
return x + '_relu'
@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():
with pytest.raises(AssertionError):
# order must be a tuple
order = ['conv', 'norm', 'act']
ConvModule(3, 8, 2, order=order)
with pytest.raises(AssertionError):
# length of order must be 3
order = ('conv', 'norm')
ConvModule(3, 8, 2, order=order)
with pytest.raises(AssertionError):
# order must be an order of 'conv', 'norm', 'act'
order = ('conv', 'norm', 'norm')
ConvModule(3, 8, 2, order=order)
with pytest.raises(AssertionError):
# order must be an order of 'conv', 'norm', 'act'
order = ('conv', 'norm', 'something')
ConvModule(3, 8, 2, order=order)
# ('conv', 'norm', 'act')
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
out = conv('input')
assert out == 'input_conv_bn_relu'
# ('norm', 'conv', 'act')
conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
out = conv('input')
assert out == 'input_bn_conv_relu'
# ('conv', 'norm', 'act'), activate=False
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
out = conv('input', activate=False)
assert out == 'input_conv_bn'
# ('conv', 'norm', 'act'), norm=False
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
out = conv('input', norm=False)
assert out == 'input_conv_relu'

(new test file: Scale)

@@ -0,0 +1,21 @@
import torch
from mmcv.cnn.bricks import Scale
def test_scale():
# test default scale
scale = Scale()
assert scale.scale.data == 1.
assert scale.scale.dtype == torch.float
x = torch.rand(1, 3, 64, 64)
output = scale(x)
assert output.shape == (1, 3, 64, 64)
# test given scale
scale = Scale(10.)
assert scale.scale.data == 10.
assert scale.scale.dtype == torch.float
x = torch.rand(1, 3, 64, 64)
output = scale(x)
assert output.shape == (1, 3, 64, 64)

(updated registry tests)

@@ -179,16 +179,16 @@ def test_build_from_cfg():
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg should contain the key "type"
-    with pytest.raises(TypeError):
+    with pytest.raises(KeyError):
cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# incorrect registry type
with pytest.raises(TypeError):
-        dict(type='ResNet', depth=50)
+        cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, 'BACKBONES')
# incorrect default_args type
with pytest.raises(TypeError):
-        dict(type='ResNet', depth=50)
+        cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)