mirror of https://github.com/open-mmlab/mmcv.git
252 lines
7.5 KiB
Python
252 lines
7.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import warnings
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
|
|
from mmcv.utils import TORCH_VERSION, digit_version
|
|
|
|
|
|
@CONV_LAYERS.register_module()
class ExampleConv(nn.Module):
    """Minimal custom conv layer registered in ``CONV_LAYERS`` for tests.

    It exposes the same attribute names as ``nn.Conv2d`` (``in_channels``,
    ``stride``, ``padding``, ``output_padding``, ...), presumably because
    ``ConvModule`` reads them from the conv it wraps (verify against mmcv).
    It also defines its own ``init_weights`` so tests can check that
    ``ConvModule`` keeps a custom initialization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int | tuple): Size of the convolving kernel.
        stride (int | tuple): Stride of the convolution. Default: 1.
        padding (int | tuple): Zero-padding added to both sides. Default: 0.
        dilation (int | tuple): Spacing between kernel elements. Default: 1.
        groups (int): Number of blocked connections. Default: 1.
        bias (bool): Whether the conv has a learnable bias. Default: True.
        norm_cfg (dict | None): Stored only; not used by this layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm_cfg=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.norm_cfg = norm_cfg
        self.transposed = False

        # Forward every hyper-parameter so the attributes above describe what
        # the layer actually computes (previously only the kernel size was
        # applied and stride/padding/dilation/groups/bias were ignored).
        self.conv0 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Take output_padding from the real 2-D conv so it has one entry per
        # spatial dim (the old hard-coded value was a 3-tuple).
        self.output_padding = self.conv0.output_padding

        self.init_weights()

    def forward(self, x):
        """Apply the wrapped ``nn.Conv2d`` to ``x``."""
        return self.conv0(x)

    def init_weights(self):
        """Zero the conv weight so tests can detect that this method ran.

        The bias, if any, keeps PyTorch's default initialization.
        """
        nn.init.constant_(self.conv0.weight, 0)
|
|
|
|
|
|
def test_conv_module():
    """Check ConvModule config validation, layer composition and shapes."""
    # Invalid configs must be rejected at construction time:
    #   * conv_cfg must be a dict or None
    #   * norm_cfg must be a dict or None
    #   * softmax is not a supported activation
    invalid_cases = [
        (AssertionError, dict(conv_cfg='conv')),
        (AssertionError, dict(norm_cfg='norm')),
        (KeyError, dict(act_cfg=dict(type='softmax'))),
    ]
    for expected_error, bad_kwargs in invalid_cases:
        with pytest.raises(expected_error):
            ConvModule(3, 8, 2, **bad_kwargs)

    # conv + norm + act
    module = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert module.with_activation and hasattr(module, 'activate')
    assert module.with_norm and hasattr(module, 'norm')
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == (1, 8, 255, 255)

    # conv + act
    module = ConvModule(3, 8, 2)
    assert module.with_activation and hasattr(module, 'activate')
    assert not module.with_norm and module.norm is None
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == (1, 8, 255, 255)

    # conv only
    module = ConvModule(3, 8, 2, act_cfg=None)
    assert not module.with_norm and module.norm is None
    assert not module.with_activation
    assert not hasattr(module, 'activate')
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == (1, 8, 255, 255)

    # a conv layer with its own `init_weights` keeps its initialization
    conv_module = ConvModule(
        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
    assert torch.equal(conv_module.conv.conv0.weight, torch.zeros(8, 3, 2, 2))

    # with_spectral_norm=True adds `weight_orig` to the wrapped conv
    module = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(module.conv, 'weight_orig')
    assert module(inputs).shape == (1, 8, 256, 256)

    # padding_mode='reflect' uses a dedicated padding layer
    module = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(module.padding_layer, nn.ReflectionPad2d)
    assert module(inputs).shape == (1, 8, 256, 256)

    # a non-existing padding mode is rejected
    with pytest.raises(KeyError):
        ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # Each activation type and the module class it should build. HSwish maps
    # to mmcv's own implementation on parrots / torch < 1.7.
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.7')):
        hswish_cls = HSwish
    else:
        hswish_cls = nn.Hardswish
    expected_activations = (
        ('LeakyReLU', nn.LeakyReLU),
        ('Tanh', nn.Tanh),
        ('Sigmoid', nn.Sigmoid),
        ('PReLU', nn.PReLU),
        ('HSwish', hswish_cls),
        ('HSigmoid', HSigmoid),
    )
    for act_type, act_cls in expected_activations:
        module = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type=act_type))
        assert isinstance(module.activate, act_cls), act_type
        assert module(inputs).shape == (1, 8, 256, 256)
|
|
|
|
|
|
def test_bias():
    """Check how ConvModule resolves ``bias`` and when it warns about it."""
    # bias='auto' (default): keep the bias only when there is no norm layer
    assert ConvModule(3, 8, 2).conv.bias is not None
    assert ConvModule(3, 8, 2, norm_cfg=dict(type='BN')).conv.bias is None

    # bias=False is honoured even without a norm layer
    assert ConvModule(3, 8, 2, bias=False).conv.bias is None

    # bias=True followed by batch or instance norm emits exactly one warning
    for norm_type in ('BN', 'IN'):
        with pytest.warns(UserWarning) as record:
            ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type=norm_type))
        assert len(record) == 1
        assert record[0].message.args[0] == (
            'Unnecessary conv bias before batch/instance norm')

    # bias=True with any other norm must not warn. pytest.warns needs at
    # least one warning, so emit a sentinel and check it is the only one.
    with pytest.warns(UserWarning) as record:
        group_norm_cfg = dict(type='GN', num_groups=1)
        ConvModule(3, 8, 2, bias=True, norm_cfg=group_norm_cfg)
        warnings.warn('No warnings')
    assert len(record) == 1
    assert record[0].message.args[0] == 'No warnings'
|
|
|
|
|
|
def conv_forward(self, x):
    """Stand-in for ``nn.Conv2d.forward``: append the ``'_conv'`` tag to x."""
    tag = '_conv'
    return x + tag
|
|
|
|
|
|
def bn_forward(self, x):
    """Stand-in for ``nn.BatchNorm2d.forward``: append the ``'_bn'`` tag."""
    tag = '_bn'
    return x + tag
|
|
|
|
|
|
def relu_forward(self, x):
    """Stand-in for ``nn.ReLU.forward``: append the ``'_relu'`` tag to x."""
    tag = '_relu'
    return x + tag
|
|
|
|
|
|
@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():
    """Check ``order`` validation and that forward runs layers in that order.

    The forwards of Conv2d, BatchNorm2d and ReLU are patched to append a tag
    (``_conv``, ``_bn``, ``_relu``) to a string input. The output string
    therefore records which layers ran and in what sequence.
    """

    with pytest.raises(AssertionError):
        # order must be a tuple
        order = ['conv', 'norm', 'act']
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # length of order must be 3
        order = ('conv', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'something')
        ConvModule(3, 8, 2, order=order)

    # ('conv', 'norm', 'act')
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input')
    assert out == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    out = conv('input')
    assert out == 'input_bn_conv_relu'

    # ('conv', 'norm', 'act'), activate=False -> activation is skipped
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', activate=False)
    assert out == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False -> normalization is skipped
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', norm=False)
    assert out == 'input_conv_relu'
|