diff --git a/mmseg/__init__.py b/mmseg/__init__.py index f301a5dc3..d1f472c04 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -3,7 +3,7 @@ import mmcv from .version import __version__, version_info MMCV_MIN = '1.1.4' -MMCV_MAX = '1.3.0' +MMCV_MAX = '1.4.0' def digit_version(version_str): diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 413228626..8f0fc16ff 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,10 +1,11 @@ from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible from .res_layer import ResLayer +from .se_layer import SELayer from .self_attention_block import SelfAttentionBlock from .up_conv_block import UpConvBlock __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', - 'UpConvBlock', 'InvertedResidualV3' + 'UpConvBlock', 'InvertedResidualV3', 'SELayer' ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_backbones/__init__.py b/tests/test_models/test_backbones/__init__.py new file mode 100644 index 000000000..78a93a54f --- /dev/null +++ b/tests/test_models/test_backbones/__init__.py @@ -0,0 +1,3 @@ +from .utils import all_zeros, check_norm_state, is_block, is_norm + +__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state'] diff --git a/tests/test_utils/test_inverted_residual_module.py b/tests/test_models/test_backbones/test_blocks.py similarity index 71% rename from tests/test_utils/test_inverted_residual_module.py rename to tests/test_models/test_backbones/test_blocks.py index 8d5eecf15..f459fbba8 100644 --- a/tests/test_utils/test_inverted_residual_module.py +++ b/tests/test_models/test_backbones/test_blocks.py @@ -2,7 +2,20 @@ import mmcv import pytest import torch -from mmseg.models.utils import InvertedResidual, InvertedResidualV3 +from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer, + make_divisible) + + +def test_make_divisible(): + # test with min_value = None + assert make_divisible(10, 4) == 12 + assert make_divisible(9, 4) == 12 + assert make_divisible(1, 4) == 4 + + # test with min_value = 8 + assert make_divisible(10, 4, 8) == 12 + assert make_divisible(9, 4, 8) == 12 + assert make_divisible(1, 4, 8) == 8 def test_inv_residual(): @@ -118,3 +131,39 @@ def test_inv_residualv3(): x = torch.randn(2, 32, 64, 64, requires_grad=True) output = inv_module(x) assert output.shape == (2, 40, 32, 32) + + +def test_se_layer(): + with pytest.raises(AssertionError): + # test act_cfg assertion. + SELayer(32, act_cfg=(dict(type='ReLU'), )) + + # test config with channels = 16. + se_layer = SELayer(16) + assert se_layer.conv1.conv.kernel_size == (1, 1) + assert se_layer.conv1.conv.stride == (1, 1) + assert se_layer.conv1.conv.padding == (0, 0) + assert isinstance(se_layer.conv1.activate, torch.nn.ReLU) + assert se_layer.conv2.conv.kernel_size == (1, 1) + assert se_layer.conv2.conv.stride == (1, 1) + assert se_layer.conv2.conv.padding == (0, 0) + assert isinstance(se_layer.conv2.activate, mmcv.cnn.HSigmoid) + + x = torch.rand(1, 16, 64, 64) + output = se_layer(x) + assert output.shape == (1, 16, 64, 64) + + # test config with channels = 16, act_cfg = dict(type='ReLU'). + se_layer = SELayer(16, act_cfg=dict(type='ReLU')) + assert se_layer.conv1.conv.kernel_size == (1, 1) + assert se_layer.conv1.conv.stride == (1, 1) + assert se_layer.conv1.conv.padding == (0, 0) + assert isinstance(se_layer.conv1.activate, torch.nn.ReLU) + assert se_layer.conv2.conv.kernel_size == (1, 1) + assert se_layer.conv2.conv.stride == (1, 1) + assert se_layer.conv2.conv.padding == (0, 0) + assert isinstance(se_layer.conv2.activate, torch.nn.ReLU) + + x = torch.rand(1, 16, 64, 64) + output = se_layer(x) + assert output.shape == (1, 16, 64, 64) diff --git a/tests/test_models/test_backbones/test_cgnet.py b/tests/test_models/test_backbones/test_cgnet.py new file mode 100644 index 000000000..dfc4e9ade --- /dev/null +++ b/tests/test_models/test_backbones/test_cgnet.py @@ -0,0 +1,150 @@ +import pytest +import torch + +from mmseg.models.backbones import CGNet +from mmseg.models.backbones.cgnet import (ContextGuidedBlock, + GlobalContextExtractor) + + +def test_cgnet_GlobalContextExtractor(): + block = GlobalContextExtractor(16, 16, with_cp=True) + x = torch.randn(2, 16, 64, 64, requires_grad=True) + x_out = block(x) + assert x_out.shape == torch.Size([2, 16, 64, 64]) + + +def test_cgnet_context_guided_block(): + with pytest.raises(AssertionError): + # cgnet ContextGuidedBlock GlobalContextExtractor channel and reduction + # constraints. + ContextGuidedBlock(8, 8) + + # test cgnet ContextGuidedBlock with checkpoint forward + block = ContextGuidedBlock( + 16, 16, act_cfg=dict(type='PReLU'), with_cp=True) + assert block.with_cp + x = torch.randn(2, 16, 64, 64, requires_grad=True) + x_out = block(x) + assert x_out.shape == torch.Size([2, 16, 64, 64]) + + # test cgnet ContextGuidedBlock without checkpoint forward + block = ContextGuidedBlock(32, 32) + assert not block.with_cp + x = torch.randn(3, 32, 32, 32) + x_out = block(x) + assert x_out.shape == torch.Size([3, 32, 32, 32]) + + # test cgnet ContextGuidedBlock with down sampling + block = ContextGuidedBlock(32, 32, downsample=True) + assert block.conv1x1.conv.in_channels == 32 + assert block.conv1x1.conv.out_channels == 32 + assert block.conv1x1.conv.kernel_size == (3, 3) + assert block.conv1x1.conv.stride == (2, 2) + assert block.conv1x1.conv.padding == (1, 1) + + assert block.f_loc.in_channels == 32 + assert block.f_loc.out_channels == 32 + assert block.f_loc.kernel_size == (3, 3) + assert block.f_loc.stride == (1, 1) + assert block.f_loc.padding == (1, 1) + assert block.f_loc.groups == 32 + assert block.f_loc.dilation == (1, 1) + assert block.f_loc.bias is None + + assert block.f_sur.in_channels == 32 + assert block.f_sur.out_channels == 32 + assert block.f_sur.kernel_size == (3, 3) + assert block.f_sur.stride == (1, 1) + assert block.f_sur.padding == (2, 2) + assert block.f_sur.groups == 32 + assert block.f_sur.dilation == (2, 2) + assert block.f_sur.bias is None + + assert block.bottleneck.in_channels == 64 + assert block.bottleneck.out_channels == 32 + assert block.bottleneck.kernel_size == (1, 1) + assert block.bottleneck.stride == (1, 1) + assert block.bottleneck.bias is None + + x = torch.randn(1, 32, 32, 32) + x_out = block(x) + assert x_out.shape == torch.Size([1, 32, 16, 16]) + + # test cgnet ContextGuidedBlock without down sampling + block = ContextGuidedBlock(32, 32, downsample=False) + assert block.conv1x1.conv.in_channels == 32 + assert block.conv1x1.conv.out_channels == 16 + assert block.conv1x1.conv.kernel_size == (1, 1) + assert block.conv1x1.conv.stride == (1, 1) + assert block.conv1x1.conv.padding == (0, 0) + + assert block.f_loc.in_channels == 16 + assert block.f_loc.out_channels == 16 + assert block.f_loc.kernel_size == (3, 3) + assert block.f_loc.stride == (1, 1) + assert block.f_loc.padding == (1, 1) + assert block.f_loc.groups == 16 + assert block.f_loc.dilation == (1, 1) + assert block.f_loc.bias is None + + assert block.f_sur.in_channels == 16 + assert block.f_sur.out_channels == 16 + assert block.f_sur.kernel_size == (3, 3) + assert block.f_sur.stride == (1, 1) + assert block.f_sur.padding == (2, 2) + assert block.f_sur.groups == 16 + assert block.f_sur.dilation == (2, 2) + assert block.f_sur.bias is None + + x = torch.randn(1, 32, 32, 32) + x_out = block(x) + assert x_out.shape == torch.Size([1, 32, 32, 32]) + + +def test_cgnet_backbone(): + with pytest.raises(AssertionError): + # check invalid num_channels + CGNet(num_channels=(32, 64, 128, 256)) + + with pytest.raises(AssertionError): + # check invalid num_blocks + CGNet(num_blocks=(3, 21, 3)) + + with pytest.raises(AssertionError): + # check invalid dilation + CGNet(num_blocks=2) + + with pytest.raises(AssertionError): + # check invalid reduction + CGNet(reductions=16) + + with pytest.raises(AssertionError): + # check invalid num_channels and reduction + CGNet(num_channels=(32, 64, 128), reductions=(64, 129)) + + # Test CGNet with default settings + model = CGNet() + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == torch.Size([2, 35, 112, 112]) + assert feat[1].shape == torch.Size([2, 131, 56, 56]) + assert feat[2].shape == torch.Size([2, 256, 28, 28]) + + # Test CGNet with norm_eval True and with_cp True + model = CGNet(norm_eval=True, with_cp=True) + with pytest.raises(TypeError): + # check invalid pretrained + model.init_weights(pretrained=8) + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == torch.Size([2, 35, 112, 112]) + assert feat[1].shape == torch.Size([2, 131, 56, 56]) + assert feat[2].shape == torch.Size([2, 256, 28, 28]) diff --git a/tests/test_models/test_backbones/test_fast_scnn.py b/tests/test_models/test_backbones/test_fast_scnn.py new file mode 100644 index 000000000..f4a580987 --- /dev/null +++ b/tests/test_models/test_backbones/test_fast_scnn.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from mmseg.models.backbones import FastSCNN + + +def test_fastscnn_backbone(): + with pytest.raises(AssertionError): + # Fast-SCNN channel constraints. + FastSCNN( + 3, (32, 48), + 64, (64, 96, 128), (2, 2, 1), + global_out_channels=127, + higher_in_channels=64, + lower_in_channels=128) + + # Test FastSCNN Standard Forward + model = FastSCNN() + model.init_weights() + model.train() + batch_size = 4 + imgs = torch.randn(batch_size, 3, 512, 1024) + feat = model(imgs) + + assert len(feat) == 3 + # higher-res + assert feat[0].shape == torch.Size([batch_size, 64, 64, 128]) + # lower-res + assert feat[1].shape == torch.Size([batch_size, 128, 16, 32]) + # FFM output + assert feat[2].shape == torch.Size([batch_size, 128, 64, 128]) diff --git a/tests/test_models/test_backbones/test_mobilenet_v3.py b/tests/test_models/test_backbones/test_mobilenet_v3.py new file mode 100644 index 000000000..1ebeac410 --- /dev/null +++ b/tests/test_models/test_backbones/test_mobilenet_v3.py @@ -0,0 +1,66 @@ +import pytest +import torch + +from mmseg.models.backbones import MobileNetV3 + + +def test_mobilenet_v3(): + with pytest.raises(AssertionError): + # check invalid arch + MobileNetV3('big') + + with pytest.raises(AssertionError): + # check invalid reduction_factor + MobileNetV3(reduction_factor=0) + + with pytest.raises(ValueError): + # check invalid out_indices + MobileNetV3(out_indices=(0, 1, 15)) + + with pytest.raises(ValueError): + # check invalid frozen_stages + MobileNetV3(frozen_stages=15) + + with pytest.raises(TypeError): + # check invalid pretrained + model = MobileNetV3() + model.init_weights(pretrained=8) + + # Test MobileNetV3 with default settings + model = MobileNetV3() + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == (2, 16, 112, 112) + assert feat[1].shape == (2, 16, 56, 56) + assert feat[2].shape == (2, 576, 28, 28) + + # Test MobileNetV3 with arch = 'large' + model = MobileNetV3(arch='large', out_indices=(1, 3, 16)) + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == (2, 16, 112, 112) + assert feat[1].shape == (2, 24, 56, 56) + assert feat[2].shape == (2, 960, 28, 28) + + # Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5 + model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5) + with pytest.raises(TypeError): + # check invalid pretrained + model.init_weights(pretrained=8) + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == (2, 16, 112, 112) + assert feat[1].shape == (2, 16, 56, 56) + assert feat[2].shape == (2, 576, 28, 28) diff --git a/tests/test_models/test_backbones/test_resnest.py b/tests/test_models/test_backbones/test_resnest.py new file mode 100644 index 000000000..78d97de0c --- /dev/null +++ b/tests/test_models/test_backbones/test_resnest.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from mmseg.models.backbones import ResNeSt +from mmseg.models.backbones.resnest import Bottleneck as BottleneckS + + +def test_resnest_bottleneck(): + with pytest.raises(AssertionError): + # Style must be in ['pytorch', 'caffe'] + BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow') + + # Test ResNeSt Bottleneck structure + block = BottleneckS( + 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch') + assert block.avd_layer.stride == 2 + assert block.conv2.channels == 256 + + # Test ResNeSt Bottleneck forward + block = BottleneckS(64, 16, radix=2, reduction_factor=4) + x = torch.randn(2, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([2, 64, 56, 56]) + + +def test_resnest_backbone(): + with pytest.raises(KeyError): + # ResNeSt depth should be in [50, 101, 152, 200] + ResNeSt(depth=18) + + # Test ResNeSt with radix 2, reduction_factor 4 + model = ResNeSt( + depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([2, 256, 56, 56]) + assert feat[1].shape == torch.Size([2, 512, 28, 28]) + assert feat[2].shape == torch.Size([2, 1024, 14, 14]) + assert feat[3].shape == torch.Size([2, 2048, 7, 7]) diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbones/test_resnet.py similarity index 62% rename from tests/test_models/test_backbone.py rename to tests/test_models/test_backbones/test_resnet.py index 9ed6ce222..b95277ee4 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbones/test_resnet.py @@ -4,50 +4,10 @@ from mmcv.ops import DeformConv2dPack from mmcv.utils.parrots_wrapper import _BatchNorm from torch.nn.modules import AvgPool2d, GroupNorm -from mmseg.models.backbones import (CGNet, FastSCNN, MobileNetV3, ResNeSt, - ResNet, ResNetV1d, ResNeXt) -from mmseg.models.backbones.cgnet import (ContextGuidedBlock, - GlobalContextExtractor) -from mmseg.models.backbones.resnest import Bottleneck as BottleneckS +from mmseg.models.backbones import ResNet, ResNetV1d from mmseg.models.backbones.resnet import BasicBlock, Bottleneck -from mmseg.models.backbones.resnext import Bottleneck as BottleneckX from mmseg.models.utils import ResLayer - - -def is_block(modules): - """Check if is ResNet building block.""" - if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)): - return True - return False - - -def is_norm(modules): - """Check if is one of the norms.""" - if isinstance(modules, (GroupNorm, _BatchNorm)): - return True - return False - - -def all_zeros(modules): - """Check if the weight(and bias) is all zero.""" - weight_zero = torch.allclose(modules.weight.data, - torch.zeros_like(modules.weight.data)) - if hasattr(modules, 'bias'): - bias_zero = torch.allclose(modules.bias.data, - torch.zeros_like(modules.bias.data)) - else: - bias_zero = True - - return weight_zero and bias_zero - - -def check_norm_state(modules, train_state): - """Check if norm layer is in correct train state.""" - for mod in modules: - if isinstance(mod, _BatchNorm): - if mod.training != train_state: - return False - return True +from .utils import all_zeros, check_norm_state, is_block, is_norm def test_resnet_basic_block(): @@ -611,329 +571,3 @@ def test_resnet_backbone(): assert feat[1].shape == torch.Size([1, 512, 28, 28]) assert feat[2].shape == torch.Size([1, 1024, 14, 14]) assert feat[3].shape == torch.Size([1, 2048, 7, 7]) - - -def test_renext_bottleneck(): - with pytest.raises(AssertionError): - # Style must be in ['pytorch', 'caffe'] - BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow') - - # Test ResNeXt Bottleneck structure - block = BottleneckX( - 64, 64, groups=32, base_width=4, stride=2, style='pytorch') - assert block.conv2.stride == (2, 2) - assert block.conv2.groups == 32 - assert block.conv2.out_channels == 128 - - # Test ResNeXt Bottleneck with DCN - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) - with pytest.raises(AssertionError): - # conv_cfg must be None if dcn is not None - BottleneckX( - 64, - 64, - groups=32, - base_width=4, - dcn=dcn, - conv_cfg=dict(type='Conv')) - BottleneckX(64, 64, dcn=dcn) - - # Test ResNeXt Bottleneck forward - block = BottleneckX(64, 16, groups=32, base_width=4) - x = torch.randn(1, 64, 56, 56) - x_out = block(x) - assert x_out.shape == torch.Size([1, 64, 56, 56]) - - -def test_resnext_backbone(): - with pytest.raises(KeyError): - # ResNeXt depth should be in [50, 101, 152] - ResNeXt(depth=18) - - # Test ResNeXt with group 32, base_width 4 - model = ResNeXt(depth=50, groups=32, base_width=4) - print(model) - for m in model.modules(): - if is_block(m): - assert m.conv2.groups == 32 - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 4 - assert feat[0].shape == torch.Size([1, 256, 56, 56]) - assert feat[1].shape == torch.Size([1, 512, 28, 28]) - assert feat[2].shape == torch.Size([1, 1024, 14, 14]) - assert feat[3].shape == torch.Size([1, 2048, 7, 7]) - - -def test_fastscnn_backbone(): - with pytest.raises(AssertionError): - # Fast-SCNN channel constraints. - FastSCNN( - 3, (32, 48), - 64, (64, 96, 128), (2, 2, 1), - global_out_channels=127, - higher_in_channels=64, - lower_in_channels=128) - - # Test FastSCNN Standard Forward - model = FastSCNN() - model.init_weights() - model.train() - batch_size = 4 - imgs = torch.randn(batch_size, 3, 512, 1024) - feat = model(imgs) - - assert len(feat) == 3 - # higher-res - assert feat[0].shape == torch.Size([batch_size, 64, 64, 128]) - # lower-res - assert feat[1].shape == torch.Size([batch_size, 128, 16, 32]) - # FFM output - assert feat[2].shape == torch.Size([batch_size, 128, 64, 128]) - - -def test_resnest_bottleneck(): - with pytest.raises(AssertionError): - # Style must be in ['pytorch', 'caffe'] - BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow') - - # Test ResNeSt Bottleneck structure - block = BottleneckS( - 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch') - assert block.avd_layer.stride == 2 - assert block.conv2.channels == 256 - - # Test ResNeSt Bottleneck forward - block = BottleneckS(64, 16, radix=2, reduction_factor=4) - x = torch.randn(2, 64, 56, 56) - x_out = block(x) - assert x_out.shape == torch.Size([2, 64, 56, 56]) - - -def test_resnest_backbone(): - with pytest.raises(KeyError): - # ResNeSt depth should be in [50, 101, 152, 200] - ResNeSt(depth=18) - - # Test ResNeSt with radix 2, reduction_factor 4 - model = ResNeSt( - depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 4 - assert feat[0].shape == torch.Size([2, 256, 56, 56]) - assert feat[1].shape == torch.Size([2, 512, 28, 28]) - assert feat[2].shape == torch.Size([2, 1024, 14, 14]) - assert feat[3].shape == torch.Size([2, 2048, 7, 7]) - - -def test_cgnet_GlobalContextExtractor(): - block = GlobalContextExtractor(16, 16, with_cp=True) - x = torch.randn(2, 16, 64, 64, requires_grad=True) - x_out = block(x) - assert x_out.shape == torch.Size([2, 16, 64, 64]) - - -def test_cgnet_context_guided_block(): - with pytest.raises(AssertionError): - # cgnet ContextGuidedBlock GlobalContextExtractor channel and reduction - # constraints. - ContextGuidedBlock(8, 8) - - # test cgnet ContextGuidedBlock with checkpoint forward - block = ContextGuidedBlock( - 16, 16, act_cfg=dict(type='PReLU'), with_cp=True) - assert block.with_cp - x = torch.randn(2, 16, 64, 64, requires_grad=True) - x_out = block(x) - assert x_out.shape == torch.Size([2, 16, 64, 64]) - - # test cgnet ContextGuidedBlock without checkpoint forward - block = ContextGuidedBlock(32, 32) - assert not block.with_cp - x = torch.randn(3, 32, 32, 32) - x_out = block(x) - assert x_out.shape == torch.Size([3, 32, 32, 32]) - - # test cgnet ContextGuidedBlock with down sampling - block = ContextGuidedBlock(32, 32, downsample=True) - assert block.conv1x1.conv.in_channels == 32 - assert block.conv1x1.conv.out_channels == 32 - assert block.conv1x1.conv.kernel_size == (3, 3) - assert block.conv1x1.conv.stride == (2, 2) - assert block.conv1x1.conv.padding == (1, 1) - - assert block.f_loc.in_channels == 32 - assert block.f_loc.out_channels == 32 - assert block.f_loc.kernel_size == (3, 3) - assert block.f_loc.stride == (1, 1) - assert block.f_loc.padding == (1, 1) - assert block.f_loc.groups == 32 - assert block.f_loc.dilation == (1, 1) - assert block.f_loc.bias is None - - assert block.f_sur.in_channels == 32 - assert block.f_sur.out_channels == 32 - assert block.f_sur.kernel_size == (3, 3) - assert block.f_sur.stride == (1, 1) - assert block.f_sur.padding == (2, 2) - assert block.f_sur.groups == 32 - assert block.f_sur.dilation == (2, 2) - assert block.f_sur.bias is None - - assert block.bottleneck.in_channels == 64 - assert block.bottleneck.out_channels == 32 - assert block.bottleneck.kernel_size == (1, 1) - assert block.bottleneck.stride == (1, 1) - assert block.bottleneck.bias is None - - x = torch.randn(1, 32, 32, 32) - x_out = block(x) - assert x_out.shape == torch.Size([1, 32, 16, 16]) - - # test cgnet ContextGuidedBlock without down sampling - block = ContextGuidedBlock(32, 32, downsample=False) - assert block.conv1x1.conv.in_channels == 32 - assert block.conv1x1.conv.out_channels == 16 - assert block.conv1x1.conv.kernel_size == (1, 1) - assert block.conv1x1.conv.stride == (1, 1) - assert block.conv1x1.conv.padding == (0, 0) - - assert block.f_loc.in_channels == 16 - assert block.f_loc.out_channels == 16 - assert block.f_loc.kernel_size == (3, 3) - assert block.f_loc.stride == (1, 1) - assert block.f_loc.padding == (1, 1) - assert block.f_loc.groups == 16 - assert block.f_loc.dilation == (1, 1) - assert block.f_loc.bias is None - - assert block.f_sur.in_channels == 16 - assert block.f_sur.out_channels == 16 - assert block.f_sur.kernel_size == (3, 3) - assert block.f_sur.stride == (1, 1) - assert block.f_sur.padding == (2, 2) - assert block.f_sur.groups == 16 - assert block.f_sur.dilation == (2, 2) - assert block.f_sur.bias is None - - x = torch.randn(1, 32, 32, 32) - x_out = block(x) - assert x_out.shape == torch.Size([1, 32, 32, 32]) - - -def test_cgnet_backbone(): - with pytest.raises(AssertionError): - # check invalid num_channels - CGNet(num_channels=(32, 64, 128, 256)) - - with pytest.raises(AssertionError): - # check invalid num_blocks - CGNet(num_blocks=(3, 21, 3)) - - with pytest.raises(AssertionError): - # check invalid dilation - CGNet(num_blocks=2) - - with pytest.raises(AssertionError): - # check invalid reduction - CGNet(reductions=16) - - with pytest.raises(AssertionError): - # check invalid num_channels and reduction - CGNet(num_channels=(32, 64, 128), reductions=(64, 129)) - - # Test CGNet with default settings - model = CGNet() - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == torch.Size([2, 35, 112, 112]) - assert feat[1].shape == torch.Size([2, 131, 56, 56]) - assert feat[2].shape == torch.Size([2, 256, 28, 28]) - - # Test CGNet with norm_eval True and with_cp True - model = CGNet(norm_eval=True, with_cp=True) - with pytest.raises(TypeError): - # check invalid pretrained - model.init_weights(pretrained=8) - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == torch.Size([2, 35, 112, 112]) - assert feat[1].shape == torch.Size([2, 131, 56, 56]) - assert feat[2].shape == torch.Size([2, 256, 28, 28]) - - -def test_mobilenet_v3(): - with pytest.raises(AssertionError): - # check invalid arch - MobileNetV3('big') - - with pytest.raises(AssertionError): - # check invalid reduction_factor - MobileNetV3(reduction_factor=0) - - with pytest.raises(ValueError): - # check invalid out_indices - MobileNetV3(out_indices=(0, 1, 15)) - - with pytest.raises(ValueError): - # check invalid frozen_stages - MobileNetV3(frozen_stages=15) - - with pytest.raises(TypeError): - # check invalid pretrained - model = MobileNetV3() - model.init_weights(pretrained=8) - - # Test MobileNetV3 with default settings - model = MobileNetV3() - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == (2, 16, 112, 112) - assert feat[1].shape == (2, 16, 56, 56) - assert feat[2].shape == (2, 576, 28, 28) - - # Test MobileNetV3 with arch = 'large' - model = MobileNetV3(arch='large', out_indices=(1, 3, 16)) - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == (2, 16, 112, 112) - assert feat[1].shape == (2, 24, 56, 56) - assert feat[2].shape == (2, 960, 28, 28) - - # Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5 - model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5) - with pytest.raises(TypeError): - # check invalid pretrained - model.init_weights(pretrained=8) - model.init_weights() - model.train() - - imgs = torch.randn(2, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == (2, 16, 112, 112) - assert feat[1].shape == (2, 16, 56, 56) - assert feat[2].shape == (2, 576, 28, 28) diff --git a/tests/test_models/test_backbones/test_resnext.py b/tests/test_models/test_backbones/test_resnext.py new file mode 100644 index 000000000..2ba5f8ec2 --- /dev/null +++ b/tests/test_models/test_backbones/test_resnext.py @@ -0,0 +1,61 @@ +import pytest +import torch + +from mmseg.models.backbones import ResNeXt +from mmseg.models.backbones.resnext import Bottleneck as BottleneckX +from .utils import is_block + + +def test_renext_bottleneck(): + with pytest.raises(AssertionError): + # Style must be in ['pytorch', 'caffe'] + BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow') + + # Test ResNeXt Bottleneck structure + block = BottleneckX( + 64, 64, groups=32, base_width=4, stride=2, style='pytorch') + assert block.conv2.stride == (2, 2) + assert block.conv2.groups == 32 + assert block.conv2.out_channels == 128 + + # Test ResNeXt Bottleneck with DCN + dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + with pytest.raises(AssertionError): + # conv_cfg must be None if dcn is not None + BottleneckX( + 64, + 64, + groups=32, + base_width=4, + dcn=dcn, + conv_cfg=dict(type='Conv')) + BottleneckX(64, 64, dcn=dcn) + + # Test ResNeXt Bottleneck forward + block = BottleneckX(64, 16, groups=32, base_width=4) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def test_resnext_backbone(): + with pytest.raises(KeyError): + # ResNeXt depth should be in [50, 101, 152] + ResNeXt(depth=18) + + # Test ResNeXt with group 32, base_width 4 + model = ResNeXt(depth=50, groups=32, base_width=4) + print(model) + for m in model.modules(): + if is_block(m): + assert m.conv2.groups == 32 + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) diff --git a/tests/test_models/test_unet.py b/tests/test_models/test_backbones/test_unet.py similarity index 98% rename from tests/test_models/test_unet.py rename to tests/test_models/test_backbones/test_unet.py index 59dbb67d2..b17b22a05 100644 --- a/tests/test_models/test_unet.py +++ b/tests/test_models/test_backbones/test_unet.py @@ -1,20 +1,11 @@ import pytest import torch from mmcv.cnn import ConvModule -from mmcv.utils.parrots_wrapper import _BatchNorm from torch import nn from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, InterpConv, UNet, UpConvBlock) - - -def check_norm_state(modules, train_state): - """Check if norm layer is in correct train state.""" - for mod in modules: - if isinstance(mod, _BatchNorm): - if mod.training != train_state: - return False - return True +from .utils import check_norm_state def test_unet_basic_conv_block(): diff --git a/tests/test_models/test_backbones/utils.py b/tests/test_models/test_backbones/utils.py new file mode 100644 index 000000000..d50b772c5 --- /dev/null +++ b/tests/test_models/test_backbones/utils.py @@ -0,0 +1,42 @@ +import torch +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.models.backbones.resnet import BasicBlock, Bottleneck +from mmseg.models.backbones.resnext import Bottleneck as BottleneckX + + +def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)): + return True + return False + + +def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + +def all_zeros(modules): + """Check if the weight(and bias) is all zero.""" + weight_zero = torch.allclose(modules.weight.data, + torch.zeros_like(modules.weight.data)) + if hasattr(modules, 'bias'): + bias_zero = torch.allclose(modules.bias.data, + torch.zeros_like(modules.bias.data)) + else: + bias_zero = True + + return weight_zero and bias_zero + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py deleted file mode 100644 index e8a8493c1..000000000 --- a/tests/test_models/test_heads.py +++ /dev/null @@ -1,834 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch -from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule -from mmcv.utils import ConfigDict -from mmcv.utils.parrots_wrapper import SyncBatchNorm - -from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead, - DAHead, DepthwiseSeparableASPPHead, - DepthwiseSeparableFCNHead, DMHead, - DNLHead, EMAHead, EncHead, FCNHead, - GCHead, LRASPPHead, NLHead, OCRHead, - PointHead, PSAHead, PSPHead, UPerHead) -from mmseg.models.decode_heads.decode_head import BaseDecodeHead - - -def _conv_has_norm(module, sync_bn): - for m in module.modules(): - if isinstance(m, ConvModule): - if not m.with_norm: - return False - if sync_bn: - if not isinstance(m.bn, SyncBatchNorm): - return False - return True - - -def to_cuda(module, data): - module = module.cuda() - if isinstance(data, list): - for i in range(len(data)): - data[i] = data[i].cuda() - return module, data - - -@patch.multiple(BaseDecodeHead, __abstractmethods__=set()) -def test_decode_head(): - - with pytest.raises(AssertionError): - # default input_transform doesn't accept multiple inputs - BaseDecodeHead([32, 16], 16, num_classes=19) - - with pytest.raises(AssertionError): - # default input_transform doesn't accept multiple inputs - BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2]) - - with pytest.raises(AssertionError): - # supported mode is resize_concat only - BaseDecodeHead(32, 16, num_classes=19, input_transform='concat') - - with pytest.raises(AssertionError): - # in_channels should be list|tuple - BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat') - - with pytest.raises(AssertionError): - # in_index should be list|tuple - BaseDecodeHead([32], - 16, - in_index=-1, - num_classes=19, - input_transform='resize_concat') - - with pytest.raises(AssertionError): - # len(in_index) should equal len(in_channels) - BaseDecodeHead([32, 16], - 16, - num_classes=19, - in_index=[-1], - input_transform='resize_concat') - - # test default dropout - head = BaseDecodeHead(32, 16, num_classes=19) - assert hasattr(head, 'dropout') and head.dropout.p == 0.1 - - # test set dropout - head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2) - assert hasattr(head, 'dropout') and head.dropout.p == 0.2 - - # test no input_transform - inputs = [torch.randn(1, 32, 45, 45)] - head = BaseDecodeHead(32, 16, num_classes=19) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.in_channels == 32 - assert head.input_transform is None - transformed_inputs = head._transform_inputs(inputs) - assert transformed_inputs.shape == (1, 32, 45, 45) - - # test input_transform = resize_concat - inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] - head = BaseDecodeHead([32, 16], - 16, - num_classes=19, - in_index=[0, 1], - input_transform='resize_concat') - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.in_channels == 48 - assert head.input_transform == 'resize_concat' - transformed_inputs = head._transform_inputs(inputs) - assert transformed_inputs.shape == (1, 48, 45, 45) - - -def test_fcn_head(): - - with pytest.raises(AssertionError): - # num_convs must be not less than 0 - FCNHead(num_classes=19, num_convs=-1) - - # test no norm_cfg - head = FCNHead(in_channels=32, channels=16, num_classes=19) - for m in head.modules(): - if isinstance(m, ConvModule): - assert not m.with_norm - - # test with norm_cfg - head = FCNHead( - in_channels=32, - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN')) - for m in head.modules(): - if isinstance(m, ConvModule): - assert m.with_norm and isinstance(m.bn, SyncBatchNorm) - - # test concat_input=False - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead( - in_channels=32, channels=16, num_classes=19, concat_input=False) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert len(head.convs) == 2 - assert not head.concat_input and not hasattr(head, 'conv_cat') - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test concat_input=True - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead( - in_channels=32, channels=16, num_classes=19, concat_input=True) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert len(head.convs) == 2 - assert head.concat_input - assert head.conv_cat.in_channels == 48 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test kernel_size=3 - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead(in_channels=32, channels=16, num_classes=19) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - for i in range(len(head.convs)): - assert head.convs[i].kernel_size == (3, 3) - assert head.convs[i].padding == 1 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test kernel_size=1 - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead(in_channels=32, channels=16, num_classes=19, kernel_size=1) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - for i in range(len(head.convs)): - assert head.convs[i].kernel_size == (1, 1) - assert head.convs[i].padding == 0 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test num_conv - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead(in_channels=32, channels=16, num_classes=19, num_convs=1) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert len(head.convs) == 1 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test num_conv = 0 - inputs = [torch.randn(1, 32, 45, 45)] - head = FCNHead( - in_channels=32, - channels=32, - num_classes=19, - num_convs=0, - concat_input=False) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert isinstance(head.convs, torch.nn.Identity) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_psp_head(): - - with pytest.raises(AssertionError): - # pool_scales must be list|tuple - PSPHead(in_channels=32, channels=16, num_classes=19, pool_scales=1) - - # test no norm_cfg - head = PSPHead(in_channels=32, channels=16, num_classes=19) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = PSPHead( - in_channels=32, - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN')) - assert _conv_has_norm(head, sync_bn=True) - - inputs = [torch.randn(1, 32, 45, 45)] - head = PSPHead( - in_channels=32, channels=16, num_classes=19, pool_scales=(1, 2, 3)) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.psp_modules[0][0].output_size == 1 - assert head.psp_modules[1][0].output_size == 2 - assert head.psp_modules[2][0].output_size == 3 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_apc_head(): - - with pytest.raises(AssertionError): - # pool_scales must be list|tuple - APCHead(in_channels=32, channels=16, num_classes=19, pool_scales=1) - - # test no norm_cfg - head = APCHead(in_channels=32, channels=16, num_classes=19) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = APCHead( - in_channels=32, - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN')) - assert _conv_has_norm(head, sync_bn=True) - - # fusion=True - inputs = [torch.randn(1, 32, 45, 45)] - head = APCHead( - in_channels=32, - channels=16, - num_classes=19, - pool_scales=(1, 2, 3), - fusion=True) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.fusion is True - assert head.acm_modules[0].pool_scale == 1 - assert head.acm_modules[1].pool_scale == 2 - assert head.acm_modules[2].pool_scale == 3 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # fusion=False - inputs = [torch.randn(1, 32, 45, 45)] - head = APCHead( - in_channels=32, - channels=16, - num_classes=19, - pool_scales=(1, 2, 3), - fusion=False) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.fusion is False - assert head.acm_modules[0].pool_scale == 1 - assert head.acm_modules[1].pool_scale == 2 - assert head.acm_modules[2].pool_scale == 3 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_dm_head(): - - with pytest.raises(AssertionError): - # filter_sizes must be list|tuple - DMHead(in_channels=32, channels=16, num_classes=19, filter_sizes=1) - - # test no norm_cfg - head = DMHead(in_channels=32, channels=16, num_classes=19) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = DMHead( - in_channels=32, - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN')) - assert _conv_has_norm(head, sync_bn=True) - - # fusion=True - inputs = [torch.randn(1, 32, 45, 45)] - head = DMHead( - in_channels=32, - channels=16, - num_classes=19, - filter_sizes=(1, 3, 5), - fusion=True) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.fusion is True - assert head.dcm_modules[0].filter_size == 1 - assert head.dcm_modules[1].filter_size == 3 - assert head.dcm_modules[2].filter_size == 5 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # fusion=False - inputs = [torch.randn(1, 32, 45, 45)] - head = DMHead( - in_channels=32, - channels=16, - num_classes=19, - filter_sizes=(1, 3, 5), - fusion=False) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.fusion is False - assert head.dcm_modules[0].filter_size == 1 - assert head.dcm_modules[1].filter_size == 3 - assert head.dcm_modules[2].filter_size == 5 - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_aspp_head(): - - with pytest.raises(AssertionError): - # pool_scales must be list|tuple - ASPPHead(in_channels=32, channels=16, num_classes=19, dilations=1) - - # test no norm_cfg - head = ASPPHead(in_channels=32, channels=16, num_classes=19) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = ASPPHead( - in_channels=32, - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN')) - assert _conv_has_norm(head, sync_bn=True) - - inputs = [torch.randn(1, 32, 45, 45)] - head = ASPPHead( - in_channels=32, channels=16, num_classes=19, dilations=(1, 12, 24)) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.aspp_modules[0].conv.dilation == (1, 1) - assert head.aspp_modules[1].conv.dilation == (12, 12) - assert head.aspp_modules[2].conv.dilation == (24, 24) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_psa_head(): - - with pytest.raises(AssertionError): - # psa_type must be in 'bi-direction', 'collect', 'distribute' - PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - psa_type='gather') - - # test no norm_cfg - head = PSAHead( - in_channels=32, channels=16, num_classes=19, mask_size=(39, 39)) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - norm_cfg=dict(type='SyncBN')) - assert _conv_has_norm(head, sync_bn=True) - - # test 'bi-direction' psa_type - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, channels=16, num_classes=19, mask_size=(39, 39)) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'bi-direction' psa_type, shrink_factor=1 - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - shrink_factor=1) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'bi-direction' psa_type with soft_max - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - psa_softmax=True) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'collect' psa_type - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - psa_type='collect') - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'collect' psa_type, shrink_factor=1 - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - shrink_factor=1, - psa_type='collect') - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'collect' psa_type, shrink_factor=1, compact=True - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - psa_type='collect', - shrink_factor=1, - compact=True) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - # test 'distribute' psa_type - inputs = [torch.randn(1, 32, 39, 39)] - head = PSAHead( - in_channels=32, - channels=16, - num_classes=19, - mask_size=(39, 39), - psa_type='distribute') - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 39, 39) - - -def test_gc_head(): - head = GCHead(in_channels=32, channels=16, num_classes=19) - assert len(head.convs) == 2 - assert hasattr(head, 'gc_block') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_nl_head(): - head = NLHead(in_channels=32, channels=16, num_classes=19) - assert len(head.convs) == 2 - assert hasattr(head, 'nl_block') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_cc_head(): - head = CCHead(in_channels=32, channels=16, num_classes=19) - assert len(head.convs) == 2 - assert hasattr(head, 'cca') - if not torch.cuda.is_available(): - pytest.skip('CCHead requires CUDA') - inputs = [torch.randn(1, 32, 45, 45)] - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_uper_head(): - - with pytest.raises(AssertionError): - # fpn_in_channels must be list|tuple - UPerHead(in_channels=32, channels=16, num_classes=19) - - # test no norm_cfg - head = UPerHead( - in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1]) - assert not _conv_has_norm(head, sync_bn=False) - - # test with norm_cfg - head = UPerHead( - in_channels=[32, 16], - channels=16, - num_classes=19, - norm_cfg=dict(type='SyncBN'), - in_index=[-2, -1]) - assert _conv_has_norm(head, sync_bn=True) - - inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] - head = UPerHead( - in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1]) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_ann_head(): - - inputs = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)] - head = ANNHead( - in_channels=[16, 32], - channels=16, - num_classes=19, - in_index=[-2, -1], - project_channels=8) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 21, 21) - - -def test_da_head(): - - inputs = [torch.randn(1, 32, 45, 45)] - head = DAHead(in_channels=32, channels=16, num_classes=19, pam_channels=8) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert isinstance(outputs, tuple) and len(outputs) == 3 - for output in outputs: - assert output.shape == (1, head.num_classes, 45, 45) - test_output = head.forward_test(inputs, None, None) - assert test_output.shape == (1, head.num_classes, 45, 45) - - -def test_ocr_head(): - - inputs = [torch.randn(1, 32, 45, 45)] - ocr_head = OCRHead( - in_channels=32, channels=16, num_classes=19, ocr_channels=8) - fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19) - if torch.cuda.is_available(): - head, inputs = to_cuda(ocr_head, inputs) - head, inputs = to_cuda(fcn_head, inputs) - prev_output = fcn_head(inputs) - output = ocr_head(inputs, prev_output) - assert output.shape == (1, ocr_head.num_classes, 45, 45) - - -def test_enc_head(): - # with se_loss, w.o. lateral - inputs = [torch.randn(1, 32, 21, 21)] - head = EncHead( - in_channels=[32], channels=16, num_classes=19, in_index=[-1]) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert isinstance(outputs, tuple) and len(outputs) == 2 - assert outputs[0].shape == (1, head.num_classes, 21, 21) - assert outputs[1].shape == (1, head.num_classes) - - # w.o se_loss, w.o. lateral - inputs = [torch.randn(1, 32, 21, 21)] - head = EncHead( - in_channels=[32], - channels=16, - use_se_loss=False, - num_classes=19, - in_index=[-1]) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 21, 21) - - # with se_loss, with lateral - inputs = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)] - head = EncHead( - in_channels=[16, 32], - channels=16, - add_lateral=True, - num_classes=19, - in_index=[-2, -1]) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert isinstance(outputs, tuple) and len(outputs) == 2 - assert outputs[0].shape == (1, head.num_classes, 21, 21) - assert outputs[1].shape == (1, head.num_classes) - test_output = head.forward_test(inputs, None, None) - assert test_output.shape == (1, head.num_classes, 21, 21) - - -def test_dw_aspp_head(): - - # test w.o. c1 - inputs = [torch.randn(1, 32, 45, 45)] - head = DepthwiseSeparableASPPHead( - c1_in_channels=0, - c1_channels=0, - in_channels=32, - channels=16, - num_classes=19, - dilations=(1, 12, 24)) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.c1_bottleneck is None - assert head.aspp_modules[0].conv.dilation == (1, 1) - assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) - assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # test with c1 - inputs = [torch.randn(1, 8, 45, 45), torch.randn(1, 32, 21, 21)] - head = DepthwiseSeparableASPPHead( - c1_in_channels=8, - c1_channels=4, - in_channels=32, - channels=16, - num_classes=19, - dilations=(1, 12, 24)) - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - assert head.c1_bottleneck.in_channels == 8 - assert head.c1_bottleneck.out_channels == 4 - assert head.aspp_modules[0].conv.dilation == (1, 1) - assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) - assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_sep_fcn_head(): - # test sep_fcn_head with concat_input=False - head = DepthwiseSeparableFCNHead( - in_channels=128, - channels=128, - concat_input=False, - num_classes=19, - in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) - x = [torch.rand(2, 128, 32, 32)] - output = head(x) - assert output.shape == (2, head.num_classes, 32, 32) - assert not head.concat_input - assert isinstance(head.convs[0], DepthwiseSeparableConvModule) - assert isinstance(head.convs[1], DepthwiseSeparableConvModule) - assert head.conv_seg.kernel_size == (1, 1) - - head = DepthwiseSeparableFCNHead( - in_channels=64, - channels=64, - concat_input=True, - num_classes=19, - in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) - x = [torch.rand(3, 64, 32, 32)] - output = head(x) - assert output.shape == (3, head.num_classes, 32, 32) - assert head.concat_input - assert isinstance(head.convs[0], DepthwiseSeparableConvModule) - assert isinstance(head.convs[1], DepthwiseSeparableConvModule) - - -def test_dnl_head(): - # DNL with 'embedded_gaussian' mode - head = DNLHead(in_channels=32, channels=16, num_classes=19) - assert len(head.convs) == 2 - assert hasattr(head, 'dnl_block') - assert head.dnl_block.temperature == 0.05 - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # NonLocal2d with 'dot_product' mode - head = DNLHead( - in_channels=32, channels=16, num_classes=19, mode='dot_product') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # NonLocal2d with 'gaussian' mode - head = DNLHead( - in_channels=32, channels=16, num_classes=19, mode='gaussian') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - # NonLocal2d with 'concatenation' mode - head = DNLHead( - in_channels=32, channels=16, num_classes=19, mode='concatenation') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_emanet_head(): - head = EMAHead( - in_channels=32, - ema_channels=24, - channels=16, - num_stages=3, - num_bases=16, - num_classes=19) - for param in head.ema_mid_conv.parameters(): - assert not param.requires_grad - assert hasattr(head, 'ema_module') - inputs = [torch.randn(1, 32, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_point_head(): - - inputs = [torch.randn(1, 32, 45, 45)] - point_head = PointHead( - in_channels=[32], in_index=[0], channels=16, num_classes=19) - assert len(point_head.fcs) == 3 - fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19) - if torch.cuda.is_available(): - head, inputs = to_cuda(point_head, inputs) - head, inputs = to_cuda(fcn_head, inputs) - prev_output = fcn_head(inputs) - test_cfg = ConfigDict( - subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) - output = point_head.forward_test(inputs, prev_output, None, test_cfg) - assert output.shape == (1, point_head.num_classes, 180, 180) - - -def test_lraspp_head(): - with pytest.raises(ValueError): - # check invalid input_transform - LRASPPHead( - in_channels=(16, 16, 576), - in_index=(0, 1, 2), - channels=128, - input_transform='resize_concat', - dropout_ratio=0.1, - num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) - - with pytest.raises(AssertionError): - # check invalid branch_channels - LRASPPHead( - in_channels=(16, 16, 576), - in_index=(0, 1, 2), - channels=128, - branch_channels=64, - input_transform='multiple_select', - dropout_ratio=0.1, - num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) - - # test with default settings - lraspp_head = LRASPPHead( - in_channels=(16, 16, 576), - in_index=(0, 1, 2), - channels=128, - input_transform='multiple_select', - dropout_ratio=0.1, - num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) - inputs = [ - torch.randn(2, 16, 45, 45), - torch.randn(2, 16, 28, 28), - torch.randn(2, 576, 14, 14) - ] - with pytest.raises(RuntimeError): - # check invalid inputs - output = lraspp_head(inputs) - - inputs = [ - torch.randn(2, 16, 111, 111), - torch.randn(2, 16, 77, 77), - torch.randn(2, 576, 55, 55) - ] - output = lraspp_head(inputs) - assert output.shape == (2, 19, 111, 111) diff --git a/tests/test_models/test_heads/__init__.py b/tests/test_models/test_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_heads/test_ann_head.py b/tests/test_models/test_heads/test_ann_head.py new file mode 100644 index 000000000..61556c0a0 --- /dev/null +++ b/tests/test_models/test_heads/test_ann_head.py @@ -0,0 +1,19 @@ +import torch + +from mmseg.models.decode_heads import ANNHead +from .utils import to_cuda + + +def test_ann_head(): + + inputs = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)] + head = ANNHead( + in_channels=[16, 32], + channels=16, + num_classes=19, + in_index=[-2, -1], + project_channels=8) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 21, 21) diff --git a/tests/test_models/test_heads/test_apc_head.py b/tests/test_models/test_heads/test_apc_head.py new file mode 100644 index 000000000..37f1a559b --- /dev/null +++ b/tests/test_models/test_heads/test_apc_head.py @@ -0,0 +1,58 @@ +import pytest +import torch + +from mmseg.models.decode_heads import APCHead +from .utils import _conv_has_norm, to_cuda + + +def test_apc_head(): + + with pytest.raises(AssertionError): + # pool_scales must be list|tuple + APCHead(in_channels=32, channels=16, num_classes=19, pool_scales=1) + + # test no norm_cfg + head = APCHead(in_channels=32, channels=16, num_classes=19) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + # fusion=True + inputs = [torch.randn(1, 32, 45, 45)] + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + pool_scales=(1, 2, 3), + fusion=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is True + assert head.acm_modules[0].pool_scale == 1 + assert head.acm_modules[1].pool_scale == 2 + assert head.acm_modules[2].pool_scale == 3 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # fusion=False + inputs = [torch.randn(1, 32, 45, 45)] + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + pool_scales=(1, 2, 3), + fusion=False) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is False + assert head.acm_modules[0].pool_scale == 1 + assert head.acm_modules[1].pool_scale == 2 + assert head.acm_modules[2].pool_scale == 3 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_aspp_head.py b/tests/test_models/test_heads/test_aspp_head.py new file mode 100644 index 000000000..bd4ce56a3 --- /dev/null +++ b/tests/test_models/test_heads/test_aspp_head.py @@ -0,0 +1,75 @@ +import pytest +import torch + +from mmseg.models.decode_heads import ASPPHead, DepthwiseSeparableASPPHead +from .utils import _conv_has_norm, to_cuda + + +def test_aspp_head(): + + with pytest.raises(AssertionError): + # pool_scales must be list|tuple + ASPPHead(in_channels=32, channels=16, num_classes=19, dilations=1) + + # test no norm_cfg + head = ASPPHead(in_channels=32, channels=16, num_classes=19) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = ASPPHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + inputs = [torch.randn(1, 32, 45, 45)] + head = ASPPHead( + in_channels=32, channels=16, num_classes=19, dilations=(1, 12, 24)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.aspp_modules[0].conv.dilation == (1, 1) + assert head.aspp_modules[1].conv.dilation == (12, 12) + assert head.aspp_modules[2].conv.dilation == (24, 24) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + +def test_dw_aspp_head(): + + # test w.o. c1 + inputs = [torch.randn(1, 32, 45, 45)] + head = DepthwiseSeparableASPPHead( + c1_in_channels=0, + c1_channels=0, + in_channels=32, + channels=16, + num_classes=19, + dilations=(1, 12, 24)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.c1_bottleneck is None + assert head.aspp_modules[0].conv.dilation == (1, 1) + assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) + assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test with c1 + inputs = [torch.randn(1, 8, 45, 45), torch.randn(1, 32, 21, 21)] + head = DepthwiseSeparableASPPHead( + c1_in_channels=8, + c1_channels=4, + in_channels=32, + channels=16, + num_classes=19, + dilations=(1, 12, 24)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.c1_bottleneck.in_channels == 8 + assert head.c1_bottleneck.out_channels == 4 + assert head.aspp_modules[0].conv.dilation == (1, 1) + assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) + assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_cc_head.py b/tests/test_models/test_heads/test_cc_head.py new file mode 100644 index 000000000..12a19bf0a --- /dev/null +++ b/tests/test_models/test_heads/test_cc_head.py @@ -0,0 +1,17 @@ +import pytest +import torch + +from mmseg.models.decode_heads import CCHead +from .utils import to_cuda + + +def test_cc_head(): + head = CCHead(in_channels=32, channels=16, num_classes=19) + assert len(head.convs) == 2 + assert hasattr(head, 'cca') + if not torch.cuda.is_available(): + pytest.skip('CCHead requires CUDA') + inputs = [torch.randn(1, 32, 45, 45)] + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_da_head.py b/tests/test_models/test_heads/test_da_head.py new file mode 100644 index 000000000..20f3a2181 --- /dev/null +++ b/tests/test_models/test_heads/test_da_head.py @@ -0,0 +1,18 @@ +import torch + +from mmseg.models.decode_heads import DAHead +from .utils import to_cuda + + +def test_da_head(): + + inputs = [torch.randn(1, 32, 45, 45)] + head = DAHead(in_channels=32, channels=16, num_classes=19, pam_channels=8) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 3 + for output in outputs: + assert output.shape == (1, head.num_classes, 45, 45) + test_output = head.forward_test(inputs, None, None) + assert test_output.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_decode_head.py b/tests/test_models/test_heads/test_decode_head.py new file mode 100644 index 000000000..97262b92c --- /dev/null +++ b/tests/test_models/test_heads/test_decode_head.py @@ -0,0 +1,75 @@ +from unittest.mock import patch + +import pytest +import torch + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from .utils import to_cuda + + +@patch.multiple(BaseDecodeHead, __abstractmethods__=set()) +def test_decode_head(): + + with pytest.raises(AssertionError): + # default input_transform doesn't accept multiple inputs + BaseDecodeHead([32, 16], 16, num_classes=19) + + with pytest.raises(AssertionError): + # default input_transform doesn't accept multiple inputs + BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2]) + + with pytest.raises(AssertionError): + # supported mode is resize_concat only + BaseDecodeHead(32, 16, num_classes=19, input_transform='concat') + + with pytest.raises(AssertionError): + # in_channels should be list|tuple + BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat') + + with pytest.raises(AssertionError): + # in_index should be list|tuple + BaseDecodeHead([32], + 16, + in_index=-1, + num_classes=19, + input_transform='resize_concat') + + with pytest.raises(AssertionError): + # len(in_index) should equal len(in_channels) + BaseDecodeHead([32, 16], + 16, + num_classes=19, + in_index=[-1], + input_transform='resize_concat') + + # test default dropout + head = BaseDecodeHead(32, 16, num_classes=19) + assert hasattr(head, 'dropout') and head.dropout.p == 0.1 + + # test set dropout + head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2) + assert hasattr(head, 'dropout') and head.dropout.p == 0.2 + + # test no input_transform + inputs = [torch.randn(1, 32, 45, 45)] + head = BaseDecodeHead(32, 16, num_classes=19) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.in_channels == 32 + assert head.input_transform is None + transformed_inputs = head._transform_inputs(inputs) + assert transformed_inputs.shape == (1, 32, 45, 45) + + # test input_transform = resize_concat + inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] + head = BaseDecodeHead([32, 16], + 16, + num_classes=19, + in_index=[0, 1], + input_transform='resize_concat') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.in_channels == 48 + assert head.input_transform == 'resize_concat' + transformed_inputs = head._transform_inputs(inputs) + assert transformed_inputs.shape == (1, 48, 45, 45) diff --git a/tests/test_models/test_heads/test_dm_head.py b/tests/test_models/test_heads/test_dm_head.py new file mode 100644 index 000000000..e85127b30 --- /dev/null +++ b/tests/test_models/test_heads/test_dm_head.py @@ -0,0 +1,58 @@ +import pytest +import torch + +from mmseg.models.decode_heads import DMHead +from .utils import _conv_has_norm, to_cuda + + +def test_dm_head(): + + with pytest.raises(AssertionError): + # filter_sizes must be list|tuple + DMHead(in_channels=32, channels=16, num_classes=19, filter_sizes=1) + + # test no norm_cfg + head = DMHead(in_channels=32, channels=16, num_classes=19) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = DMHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + # fusion=True + inputs = [torch.randn(1, 32, 45, 45)] + head = DMHead( + in_channels=32, + channels=16, + num_classes=19, + filter_sizes=(1, 3, 5), + fusion=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is True + assert head.dcm_modules[0].filter_size == 1 + assert head.dcm_modules[1].filter_size == 3 + assert head.dcm_modules[2].filter_size == 5 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # fusion=False + inputs = [torch.randn(1, 32, 45, 45)] + head = DMHead( + in_channels=32, + channels=16, + num_classes=19, + filter_sizes=(1, 3, 5), + fusion=False) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is False + assert head.dcm_modules[0].filter_size == 1 + assert head.dcm_modules[1].filter_size == 3 + assert head.dcm_modules[2].filter_size == 5 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_dnl_head.py b/tests/test_models/test_heads/test_dnl_head.py new file mode 100644 index 000000000..b3e98aa27 --- /dev/null +++ b/tests/test_models/test_heads/test_dnl_head.py @@ -0,0 +1,44 @@ +import torch + +from mmseg.models.decode_heads import DNLHead +from .utils import to_cuda + + +def test_dnl_head(): + # DNL with 'embedded_gaussian' mode + head = DNLHead(in_channels=32, channels=16, num_classes=19) + assert len(head.convs) == 2 + assert hasattr(head, 'dnl_block') + assert head.dnl_block.temperature == 0.05 + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # NonLocal2d with 'dot_product' mode + head = DNLHead( + in_channels=32, channels=16, num_classes=19, mode='dot_product') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # NonLocal2d with 'gaussian' mode + head = DNLHead( + in_channels=32, channels=16, num_classes=19, mode='gaussian') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # NonLocal2d with 'concatenation' mode + head = DNLHead( + in_channels=32, channels=16, num_classes=19, mode='concatenation') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_ema_head.py b/tests/test_models/test_heads/test_ema_head.py new file mode 100644 index 000000000..4214b0c96 --- /dev/null +++ b/tests/test_models/test_heads/test_ema_head.py @@ -0,0 +1,22 @@ +import torch + +from mmseg.models.decode_heads import EMAHead +from .utils import to_cuda + + +def test_emanet_head(): + head = EMAHead( + in_channels=32, + ema_channels=24, + channels=16, + num_stages=3, + num_bases=16, + num_classes=19) + for param in head.ema_mid_conv.parameters(): + assert not param.requires_grad + assert hasattr(head, 'ema_module') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_enc_head.py b/tests/test_models/test_heads/test_enc_head.py new file mode 100644 index 000000000..3a293300f --- /dev/null +++ b/tests/test_models/test_heads/test_enc_head.py @@ -0,0 +1,47 @@ +import torch + +from mmseg.models.decode_heads import EncHead +from .utils import to_cuda + + +def test_enc_head(): + # with se_loss, w.o. lateral + inputs = [torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[32], channels=16, num_classes=19, in_index=[-1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 2 + assert outputs[0].shape == (1, head.num_classes, 21, 21) + assert outputs[1].shape == (1, head.num_classes) + + # w.o se_loss, w.o. lateral + inputs = [torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[32], + channels=16, + use_se_loss=False, + num_classes=19, + in_index=[-1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 21, 21) + + # with se_loss, with lateral + inputs = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[16, 32], + channels=16, + add_lateral=True, + num_classes=19, + in_index=[-2, -1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 2 + assert outputs[0].shape == (1, head.num_classes, 21, 21) + assert outputs[1].shape == (1, head.num_classes) + test_output = head.forward_test(inputs, None, None) + assert test_output.shape == (1, head.num_classes, 21, 21) diff --git a/tests/test_models/test_heads/test_fcn_head.py b/tests/test_models/test_heads/test_fcn_head.py new file mode 100644 index 000000000..24ae086d6 --- /dev/null +++ b/tests/test_models/test_heads/test_fcn_head.py @@ -0,0 +1,130 @@ +import pytest +import torch +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.utils.parrots_wrapper import SyncBatchNorm + +from mmseg.models.decode_heads import DepthwiseSeparableFCNHead, FCNHead +from .utils import to_cuda + + +def test_fcn_head(): + + with pytest.raises(AssertionError): + # num_convs must be not less than 0 + FCNHead(num_classes=19, num_convs=-1) + + # test no norm_cfg + head = FCNHead(in_channels=32, channels=16, num_classes=19) + for m in head.modules(): + if isinstance(m, ConvModule): + assert not m.with_norm + + # test with norm_cfg + head = FCNHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + for m in head.modules(): + if isinstance(m, ConvModule): + assert m.with_norm and isinstance(m.bn, SyncBatchNorm) + + # test concat_input=False + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead( + in_channels=32, channels=16, num_classes=19, concat_input=False) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert len(head.convs) == 2 + assert not head.concat_input and not hasattr(head, 'conv_cat') + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test concat_input=True + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead( + in_channels=32, channels=16, num_classes=19, concat_input=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert len(head.convs) == 2 + assert head.concat_input + assert head.conv_cat.in_channels == 48 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test kernel_size=3 + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead(in_channels=32, channels=16, num_classes=19) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + for i in range(len(head.convs)): + assert head.convs[i].kernel_size == (3, 3) + assert head.convs[i].padding == 1 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test kernel_size=1 + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead(in_channels=32, channels=16, num_classes=19, kernel_size=1) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + for i in range(len(head.convs)): + assert head.convs[i].kernel_size == (1, 1) + assert head.convs[i].padding == 0 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test num_conv + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead(in_channels=32, channels=16, num_classes=19, num_convs=1) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert len(head.convs) == 1 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test num_conv = 0 + inputs = [torch.randn(1, 32, 45, 45)] + head = FCNHead( + in_channels=32, + channels=32, + num_classes=19, + num_convs=0, + concat_input=False) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert isinstance(head.convs, torch.nn.Identity) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + +def test_sep_fcn_head(): + # test sep_fcn_head with concat_input=False + head = DepthwiseSeparableFCNHead( + in_channels=128, + channels=128, + concat_input=False, + num_classes=19, + in_index=-1, + norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) + x = [torch.rand(2, 128, 32, 32)] + output = head(x) + assert output.shape == (2, head.num_classes, 32, 32) + assert not head.concat_input + assert isinstance(head.convs[0], DepthwiseSeparableConvModule) + assert isinstance(head.convs[1], DepthwiseSeparableConvModule) + assert head.conv_seg.kernel_size == (1, 1) + + head = DepthwiseSeparableFCNHead( + in_channels=64, + channels=64, + concat_input=True, + num_classes=19, + in_index=-1, + norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) + x = [torch.rand(3, 64, 32, 32)] + output = head(x) + assert output.shape == (3, head.num_classes, 32, 32) + assert head.concat_input + assert isinstance(head.convs[0], DepthwiseSeparableConvModule) + assert isinstance(head.convs[1], DepthwiseSeparableConvModule) diff --git a/tests/test_models/test_heads/test_gc_head.py b/tests/test_models/test_heads/test_gc_head.py new file mode 100644 index 000000000..5201730b0 --- /dev/null +++ b/tests/test_models/test_heads/test_gc_head.py @@ -0,0 +1,15 @@ +import torch + +from mmseg.models.decode_heads import GCHead +from .utils import to_cuda + + +def test_gc_head(): + head = GCHead(in_channels=32, channels=16, num_classes=19) + assert len(head.convs) == 2 + assert hasattr(head, 'gc_block') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_lraspp_head.py b/tests/test_models/test_heads/test_lraspp_head.py new file mode 100644 index 000000000..5031936c7 --- /dev/null +++ b/tests/test_models/test_heads/test_lraspp_head.py @@ -0,0 +1,67 @@ +import pytest +import torch + +from mmseg.models.decode_heads import LRASPPHead + + +def test_lraspp_head(): + with pytest.raises(ValueError): + # check invalid input_transform + LRASPPHead( + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='resize_concat', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + + with pytest.raises(AssertionError): + # check invalid branch_channels + LRASPPHead( + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + branch_channels=64, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + + # test with default settings + lraspp_head = LRASPPHead( + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + inputs = [ + torch.randn(2, 16, 45, 45), + torch.randn(2, 16, 28, 28), + torch.randn(2, 576, 14, 14) + ] + with pytest.raises(RuntimeError): + # check invalid inputs + output = lraspp_head(inputs) + + inputs = [ + torch.randn(2, 16, 111, 111), + torch.randn(2, 16, 77, 77), + torch.randn(2, 576, 55, 55) + ] + output = lraspp_head(inputs) + assert output.shape == (2, 19, 111, 111) diff --git a/tests/test_models/test_heads/test_nl_head.py b/tests/test_models/test_heads/test_nl_head.py new file mode 100644 index 000000000..6f4bede5e --- /dev/null +++ b/tests/test_models/test_heads/test_nl_head.py @@ -0,0 +1,15 @@ +import torch + +from mmseg.models.decode_heads import NLHead +from .utils import to_cuda + + +def test_nl_head(): + head = NLHead(in_channels=32, channels=16, num_classes=19) + assert len(head.convs) == 2 + assert hasattr(head, 'nl_block') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_ocr_head.py b/tests/test_models/test_heads/test_ocr_head.py new file mode 100644 index 000000000..bc2af75ad --- /dev/null +++ b/tests/test_models/test_heads/test_ocr_head.py @@ -0,0 +1,18 @@ +import torch + +from mmseg.models.decode_heads import FCNHead, OCRHead +from .utils import to_cuda + + +def test_ocr_head(): + + inputs = [torch.randn(1, 32, 45, 45)] + ocr_head = OCRHead( + in_channels=32, channels=16, num_classes=19, ocr_channels=8) + fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19) + if torch.cuda.is_available(): + head, inputs = to_cuda(ocr_head, inputs) + head, inputs = to_cuda(fcn_head, inputs) + prev_output = fcn_head(inputs) + output = ocr_head(inputs, prev_output) + assert output.shape == (1, ocr_head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_point_head.py b/tests/test_models/test_heads/test_point_head.py new file mode 100644 index 000000000..b54b979de --- /dev/null +++ b/tests/test_models/test_heads/test_point_head.py @@ -0,0 +1,22 @@ +import torch +from mmcv.utils import ConfigDict + +from mmseg.models.decode_heads import FCNHead, PointHead +from .utils import to_cuda + + +def test_point_head(): + + inputs = [torch.randn(1, 32, 45, 45)] + point_head = PointHead( + in_channels=[32], in_index=[0], channels=16, num_classes=19) + assert len(point_head.fcs) == 3 + fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19) + if torch.cuda.is_available(): + head, inputs = to_cuda(point_head, inputs) + head, inputs = to_cuda(fcn_head, inputs) + prev_output = fcn_head(inputs) + test_cfg = ConfigDict( + subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) + output = point_head.forward_test(inputs, prev_output, None, test_cfg) + assert output.shape == (1, point_head.num_classes, 180, 180) diff --git a/tests/test_models/test_heads/test_psa_head.py b/tests/test_models/test_heads/test_psa_head.py new file mode 100644 index 000000000..d8f38b6aa --- /dev/null +++ b/tests/test_models/test_heads/test_psa_head.py @@ -0,0 +1,121 @@ +import pytest +import torch + +from mmseg.models.decode_heads import PSAHead +from .utils import _conv_has_norm, to_cuda + + +def test_psa_head(): + + with pytest.raises(AssertionError): + # psa_type must be in 'bi-direction', 'collect', 'distribute' + PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='gather') + + # test no norm_cfg + head = PSAHead( + in_channels=32, channels=16, num_classes=19, mask_size=(39, 39)) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + # test 'bi-direction' psa_type + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, channels=16, num_classes=19, mask_size=(39, 39)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'bi-direction' psa_type, shrink_factor=1 + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + shrink_factor=1) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'bi-direction' psa_type with soft_max + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_softmax=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='collect') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type, shrink_factor=1 + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + shrink_factor=1, + psa_type='collect') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type, shrink_factor=1, compact=True + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='collect', + shrink_factor=1, + compact=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'distribute' psa_type + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='distribute') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) diff --git a/tests/test_models/test_heads/test_psp_head.py b/tests/test_models/test_heads/test_psp_head.py new file mode 100644 index 000000000..38b39d7ba --- /dev/null +++ b/tests/test_models/test_heads/test_psp_head.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from mmseg.models.decode_heads import PSPHead +from .utils import _conv_has_norm, to_cuda + + +def test_psp_head(): + + with pytest.raises(AssertionError): + # pool_scales must be list|tuple + PSPHead(in_channels=32, channels=16, num_classes=19, pool_scales=1) + + # test no norm_cfg + head = PSPHead(in_channels=32, channels=16, num_classes=19) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = PSPHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + inputs = [torch.randn(1, 32, 45, 45)] + head = PSPHead( + in_channels=32, channels=16, num_classes=19, pool_scales=(1, 2, 3)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.psp_modules[0][0].output_size == 1 + assert head.psp_modules[1][0].output_size == 2 + assert head.psp_modules[2][0].output_size == 3 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/test_uper_head.py b/tests/test_models/test_heads/test_uper_head.py new file mode 100644 index 000000000..2c66db892 --- /dev/null +++ b/tests/test_models/test_heads/test_uper_head.py @@ -0,0 +1,34 @@ +import pytest +import torch + +from mmseg.models.decode_heads import UPerHead +from .utils import _conv_has_norm, to_cuda + + +def test_uper_head(): + + with pytest.raises(AssertionError): + # fpn_in_channels must be list|tuple + UPerHead(in_channels=32, channels=16, num_classes=19) + + # test no norm_cfg + head = UPerHead( + in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1]) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = UPerHead( + in_channels=[32, 16], + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN'), + in_index=[-2, -1]) + assert _conv_has_norm(head, sync_bn=True) + + inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] + head = UPerHead( + in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/utils.py b/tests/test_models/test_heads/utils.py new file mode 100644 index 000000000..1407f0a91 --- /dev/null +++ b/tests/test_models/test_heads/utils.py @@ -0,0 +1,21 @@ +from mmcv.cnn import ConvModule +from mmcv.utils.parrots_wrapper import SyncBatchNorm + + +def _conv_has_norm(module, sync_bn): + for m in module.modules(): + if isinstance(m, ConvModule): + if not m.with_norm: + return False + if sync_bn: + if not isinstance(m.bn, SyncBatchNorm): + return False + return True + + +def to_cuda(module, data): + module = module.cuda() + if isinstance(data, list): + for i in range(len(data)): + data[i] = data[i].cuda() + return module, data diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py deleted file mode 100644 index c58e6a505..000000000 --- a/tests/test_models/test_losses.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -import pytest -import torch - -from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss - - -def test_utils(): - loss = torch.rand(1, 3, 4, 4) - weight = torch.zeros(1, 3, 4, 4) - weight[:, :, :2, :2] = 1 - - # test reduce_loss() - reduced = reduce_loss(loss, 'none') - assert reduced is loss - - reduced = reduce_loss(loss, 'mean') - np.testing.assert_almost_equal(reduced.numpy(), loss.mean()) - - reduced = reduce_loss(loss, 'sum') - np.testing.assert_almost_equal(reduced.numpy(), loss.sum()) - - # test weight_reduce_loss() - reduced = weight_reduce_loss(loss, weight=None, reduction='none') - assert reduced is loss - - reduced = weight_reduce_loss(loss, weight=weight, reduction='mean') - target = (loss * weight).mean() - np.testing.assert_almost_equal(reduced.numpy(), target) - - reduced = weight_reduce_loss(loss, weight=weight, reduction='sum') - np.testing.assert_almost_equal(reduced.numpy(), (loss * weight).sum()) - - with pytest.raises(AssertionError): - weight_wrong = weight[0, 0, ...] - weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') - - with pytest.raises(AssertionError): - weight_wrong = weight[:, 0:2, ...] - weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') - - -def test_ce_loss(): - from mmseg.models import build_loss - - # use_mask and use_sigmoid cannot be true at the same time - with pytest.raises(AssertionError): - loss_cfg = dict( - type='CrossEntropyLoss', - use_mask=True, - use_sigmoid=True, - loss_weight=1.0) - build_loss(loss_cfg) - - # test loss with class weights - loss_cls_cfg = dict( - type='CrossEntropyLoss', - use_sigmoid=False, - class_weight=[0.8, 0.2], - loss_weight=1.0) - loss_cls = build_loss(loss_cls_cfg) - fake_pred = torch.Tensor([[100, -100]]) - fake_label = torch.Tensor([1]).long() - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) - - loss_cls_cfg = dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) - loss_cls = build_loss(loss_cls_cfg) - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.)) - - loss_cls_cfg = dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0) - loss_cls = build_loss(loss_cls_cfg) - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.)) - - fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5) - fake_label = torch.ones(2, 8, 8).long() - assert torch.allclose( - loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4) - fake_label[:, 0, 0] = 255 - assert torch.allclose( - loss_cls(fake_pred, fake_label, ignore_index=255), - torch.tensor(0.9354), - atol=1e-4) - - # TODO test use_mask - - -def test_accuracy(): - # test for empty pred - pred = torch.empty(0, 4) - label = torch.empty(0) - accuracy = Accuracy(topk=1) - acc = accuracy(pred, label) - assert acc.item() == 0 - - pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6], - [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1], - [0.0, 0.0, 0.99, 0]]) - # test for top1 - true_label = torch.Tensor([2, 3, 0, 1, 2]).long() - accuracy = Accuracy(topk=1) - acc = accuracy(pred, true_label) - assert acc.item() == 100 - - # test for top1 with score thresh=0.8 - true_label = torch.Tensor([2, 3, 0, 1, 2]).long() - accuracy = Accuracy(topk=1, thresh=0.8) - acc = accuracy(pred, true_label) - assert acc.item() == 40 - - # test for top2 - accuracy = Accuracy(topk=2) - label = torch.Tensor([3, 2, 0, 0, 2]).long() - acc = accuracy(pred, label) - assert acc.item() == 100 - - # test for both top1 and top2 - accuracy = Accuracy(topk=(1, 2)) - true_label = torch.Tensor([2, 3, 0, 1, 2]).long() - acc = accuracy(pred, true_label) - for a in acc: - assert a.item() == 100 - - # topk is larger than pred class number - with pytest.raises(AssertionError): - accuracy = Accuracy(topk=5) - accuracy(pred, true_label) - - # wrong topk type - with pytest.raises(AssertionError): - accuracy = Accuracy(topk='wrong type') - accuracy(pred, true_label) - - # label size is larger than required - with pytest.raises(AssertionError): - label = torch.Tensor([2, 3, 0, 1, 2, 0]).long() # size mismatch - accuracy = Accuracy() - accuracy(pred, label) - - # wrong pred dimension - with pytest.raises(AssertionError): - accuracy = Accuracy() - accuracy(pred[:, :, None], true_label) - - -def test_lovasz_loss(): - from mmseg.models import build_loss - - # loss_type should be 'binary' or 'multi_class' - with pytest.raises(AssertionError): - loss_cfg = dict( - type='LovaszLoss', - loss_type='Binary', - reduction='none', - loss_weight=1.0) - build_loss(loss_cfg) - - # reduction should be 'none' when per_image is False. - with pytest.raises(AssertionError): - loss_cfg = dict(type='LovaszLoss', loss_type='multi_class') - build_loss(loss_cfg) - - # test lovasz loss with loss_type = 'multi_class' and per_image = False - loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0) - lovasz_loss = build_loss(loss_cfg) - logits = torch.rand(1, 3, 4, 4) - labels = (torch.rand(1, 4, 4) * 2).long() - lovasz_loss(logits, labels) - - # test lovasz loss with loss_type = 'multi_class' and per_image = True - loss_cfg = dict( - type='LovaszLoss', - per_image=True, - reduction='mean', - class_weight=[1.0, 2.0, 3.0], - loss_weight=1.0) - lovasz_loss = build_loss(loss_cfg) - logits = torch.rand(1, 3, 4, 4) - labels = (torch.rand(1, 4, 4) * 2).long() - lovasz_loss(logits, labels, ignore_index=None) - - # test lovasz loss with loss_type = 'binary' and per_image = False - loss_cfg = dict( - type='LovaszLoss', - loss_type='binary', - reduction='none', - loss_weight=1.0) - lovasz_loss = build_loss(loss_cfg) - logits = torch.rand(2, 4, 4) - labels = (torch.rand(2, 4, 4)).long() - lovasz_loss(logits, labels) - - # test lovasz loss with loss_type = 'binary' and per_image = True - loss_cfg = dict( - type='LovaszLoss', - loss_type='binary', - per_image=True, - reduction='mean', - loss_weight=1.0) - lovasz_loss = build_loss(loss_cfg) - logits = torch.rand(2, 4, 4) - labels = (torch.rand(2, 4, 4)).long() - lovasz_loss(logits, labels, ignore_index=None) - - -def test_dice_lose(): - from mmseg.models import build_loss - - # test dice loss with loss_type = 'multi_class' - loss_cfg = dict( - type='DiceLoss', - reduction='none', - class_weight=[1.0, 2.0, 3.0], - loss_weight=1.0, - ignore_index=1) - dice_loss = build_loss(loss_cfg) - logits = torch.rand(8, 3, 4, 4) - labels = (torch.rand(8, 4, 4) * 3).long() - dice_loss(logits, labels) - - # test dice loss with loss_type = 'binary' - loss_cfg = dict( - type='DiceLoss', - smooth=2, - exponent=3, - reduction='sum', - loss_weight=1.0, - ignore_index=0) - dice_loss = build_loss(loss_cfg) - logits = torch.rand(8, 2, 4, 4) - labels = (torch.rand(8, 4, 4) * 2).long() - dice_loss(logits, labels) diff --git a/tests/test_models/test_losses/__init__.py b/tests/test_models/test_losses/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py new file mode 100644 index 000000000..35ef84348 --- /dev/null +++ b/tests/test_models/test_losses/test_ce_loss.py @@ -0,0 +1,48 @@ +import pytest +import torch + + +def test_ce_loss(): + from mmseg.models import build_loss + + # use_mask and use_sigmoid cannot be true at the same time + with pytest.raises(AssertionError): + loss_cfg = dict( + type='CrossEntropyLoss', + use_mask=True, + use_sigmoid=True, + loss_weight=1.0) + build_loss(loss_cfg) + + # test loss with class weights + loss_cls_cfg = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=[0.8, 0.2], + loss_weight=1.0) + loss_cls = build_loss(loss_cls_cfg) + fake_pred = torch.Tensor([[100, -100]]) + fake_label = torch.Tensor([1]).long() + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + + loss_cls_cfg = dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + loss_cls = build_loss(loss_cls_cfg) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.)) + + loss_cls_cfg = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0) + loss_cls = build_loss(loss_cls_cfg) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.)) + + fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5) + fake_label = torch.ones(2, 8, 8).long() + assert torch.allclose( + loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4) + fake_label[:, 0, 0] = 255 + assert torch.allclose( + loss_cls(fake_pred, fake_label, ignore_index=255), + torch.tensor(0.9354), + atol=1e-4) + + # TODO test use_mask diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py new file mode 100644 index 000000000..94b9faab7 --- /dev/null +++ b/tests/test_models/test_losses/test_dice_loss.py @@ -0,0 +1,30 @@ +import torch + + +def test_dice_lose(): + from mmseg.models import build_loss + + # test dice loss with loss_type = 'multi_class' + loss_cfg = dict( + type='DiceLoss', + reduction='none', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0, + ignore_index=1) + dice_loss = build_loss(loss_cfg) + logits = torch.rand(8, 3, 4, 4) + labels = (torch.rand(8, 4, 4) * 3).long() + dice_loss(logits, labels) + + # test dice loss with loss_type = 'binary' + loss_cfg = dict( + type='DiceLoss', + smooth=2, + exponent=3, + reduction='sum', + loss_weight=1.0, + ignore_index=0) + dice_loss = build_loss(loss_cfg) + logits = torch.rand(8, 2, 4, 4) + labels = (torch.rand(8, 4, 4) * 2).long() + dice_loss(logits, labels) diff --git a/tests/test_models/test_losses/test_lovasz_loss.py b/tests/test_models/test_losses/test_lovasz_loss.py new file mode 100644 index 000000000..e11dd613f --- /dev/null +++ b/tests/test_models/test_losses/test_lovasz_loss.py @@ -0,0 +1,62 @@ +import pytest +import torch + + +def test_lovasz_loss(): + from mmseg.models import build_loss + + # loss_type should be 'binary' or 'multi_class' + with pytest.raises(AssertionError): + loss_cfg = dict( + type='LovaszLoss', + loss_type='Binary', + reduction='none', + loss_weight=1.0) + build_loss(loss_cfg) + + # reduction should be 'none' when per_image is False. + with pytest.raises(AssertionError): + loss_cfg = dict(type='LovaszLoss', loss_type='multi_class') + build_loss(loss_cfg) + + # test lovasz loss with loss_type = 'multi_class' and per_image = False + loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + logits = torch.rand(1, 3, 4, 4) + labels = (torch.rand(1, 4, 4) * 2).long() + lovasz_loss(logits, labels) + + # test lovasz loss with loss_type = 'multi_class' and per_image = True + loss_cfg = dict( + type='LovaszLoss', + per_image=True, + reduction='mean', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + logits = torch.rand(1, 3, 4, 4) + labels = (torch.rand(1, 4, 4) * 2).long() + lovasz_loss(logits, labels, ignore_index=None) + + # test lovasz loss with loss_type = 'binary' and per_image = False + loss_cfg = dict( + type='LovaszLoss', + loss_type='binary', + reduction='none', + loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + logits = torch.rand(2, 4, 4) + labels = (torch.rand(2, 4, 4)).long() + lovasz_loss(logits, labels) + + # test lovasz loss with loss_type = 'binary' and per_image = True + loss_cfg = dict( + type='LovaszLoss', + loss_type='binary', + per_image=True, + reduction='mean', + loss_weight=1.0) + lovasz_loss = build_loss(loss_cfg) + logits = torch.rand(2, 4, 4) + labels = (torch.rand(2, 4, 4)).long() + lovasz_loss(logits, labels, ignore_index=None) diff --git a/tests/test_models/test_losses/test_utils.py b/tests/test_models/test_losses/test_utils.py new file mode 100644 index 000000000..a5251e49f --- /dev/null +++ b/tests/test_models/test_losses/test_utils.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +import torch + +from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss + + +def test_weight_reduce_loss(): + loss = torch.rand(1, 3, 4, 4) + weight = torch.zeros(1, 3, 4, 4) + weight[:, :, :2, :2] = 1 + + # test reduce_loss() + reduced = reduce_loss(loss, 'none') + assert reduced is loss + + reduced = reduce_loss(loss, 'mean') + np.testing.assert_almost_equal(reduced.numpy(), loss.mean()) + + reduced = reduce_loss(loss, 'sum') + np.testing.assert_almost_equal(reduced.numpy(), loss.sum()) + + # test weight_reduce_loss() + reduced = weight_reduce_loss(loss, weight=None, reduction='none') + assert reduced is loss + + reduced = weight_reduce_loss(loss, weight=weight, reduction='mean') + target = (loss * weight).mean() + np.testing.assert_almost_equal(reduced.numpy(), target) + + reduced = weight_reduce_loss(loss, weight=weight, reduction='sum') + np.testing.assert_almost_equal(reduced.numpy(), (loss * weight).sum()) + + with pytest.raises(AssertionError): + weight_wrong = weight[0, 0, ...] + weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') + + with pytest.raises(AssertionError): + weight_wrong = weight[:, 0:2, ...] + weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') + + +def test_accuracy(): + # test for empty pred + pred = torch.empty(0, 4) + label = torch.empty(0) + accuracy = Accuracy(topk=1) + acc = accuracy(pred, label) + assert acc.item() == 0 + + pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6], + [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1], + [0.0, 0.0, 0.99, 0]]) + # test for top1 + true_label = torch.Tensor([2, 3, 0, 1, 2]).long() + accuracy = Accuracy(topk=1) + acc = accuracy(pred, true_label) + assert acc.item() == 100 + + # test for top1 with score thresh=0.8 + true_label = torch.Tensor([2, 3, 0, 1, 2]).long() + accuracy = Accuracy(topk=1, thresh=0.8) + acc = accuracy(pred, true_label) + assert acc.item() == 40 + + # test for top2 + accuracy = Accuracy(topk=2) + label = torch.Tensor([3, 2, 0, 0, 2]).long() + acc = accuracy(pred, label) + assert acc.item() == 100 + + # test for both top1 and top2 + accuracy = Accuracy(topk=(1, 2)) + true_label = torch.Tensor([2, 3, 0, 1, 2]).long() + acc = accuracy(pred, true_label) + for a in acc: + assert a.item() == 100 + + # topk is larger than pred class number + with pytest.raises(AssertionError): + accuracy = Accuracy(topk=5) + accuracy(pred, true_label) + + # wrong topk type + with pytest.raises(AssertionError): + accuracy = Accuracy(topk='wrong type') + accuracy(pred, true_label) + + # label size is larger than required + with pytest.raises(AssertionError): + label = torch.Tensor([2, 3, 0, 1, 2, 0]).long() # size mismatch + accuracy = Accuracy() + accuracy(pred, label) + + # wrong pred dimension + with pytest.raises(AssertionError): + accuracy = Accuracy() + accuracy(pred[:, :, None], true_label) diff --git a/tests/test_models/test_necks/__init__.py b/tests/test_models/test_necks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_necks.py b/tests/test_models/test_necks/test_fpn.py similarity index 100% rename from tests/test_models/test_necks.py rename to tests/test_models/test_necks/test_fpn.py diff --git a/tests/test_models/test_segmentors/__init__.py b/tests/test_models/test_segmentors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py b/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py new file mode 100644 index 000000000..142e81f12 --- /dev/null +++ b/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py @@ -0,0 +1,56 @@ +from mmcv import ConfigDict + +from mmseg.models import build_segmentor +from .utils import _segmentor_forward_train_test + + +def test_cascade_encoder_decoder(): + + # test 1 decode head, w.o. aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test slide mode + cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 1 aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ], + auxiliary_head=dict(type='ExampleDecodeHead')) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 2 aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ], + auxiliary_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) diff --git a/tests/test_models/test_segmentors/test_encoder_decoder.py b/tests/test_models/test_segmentors/test_encoder_decoder.py new file mode 100644 index 000000000..f40c4ea47 --- /dev/null +++ b/tests/test_models/test_segmentors/test_encoder_decoder.py @@ -0,0 +1,46 @@ +from mmcv import ConfigDict + +from mmseg.models import build_segmentor +from .utils import _segmentor_forward_train_test + + +def test_encoder_decoder(): + + # test 1 decode head, w.o. aux head + + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + train_cfg=None, + test_cfg=dict(mode='whole')) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test slide mode + cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 1 aux head + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + auxiliary_head=dict(type='ExampleDecodeHead')) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 2 aux head + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + auxiliary_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) diff --git a/tests/test_models/test_segmentor.py b/tests/test_models/test_segmentors/utils.py similarity index 52% rename from tests/test_models/test_segmentor.py rename to tests/test_models/test_segmentors/utils.py index 90d3bf631..cfe9a17da 100644 --- a/tests/test_models/test_segmentor.py +++ b/tests/test_models/test_segmentors/utils.py @@ -1,9 +1,8 @@ import numpy as np import torch -from mmcv import ConfigDict from torch import nn -from mmseg.models import BACKBONES, HEADS, build_segmentor +from mmseg.models import BACKBONES, HEADS from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -118,97 +117,3 @@ def _segmentor_forward_train_test(segmentor): img_meta_list = [[img_meta] for img_meta in img_metas] img_meta_list = img_meta_list + img_meta_list segmentor.forward(img_list, img_meta_list, return_loss=False) - - -def test_encoder_decoder(): - - # test 1 decode head, w.o. aux head - - cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), - train_cfg=None, - test_cfg=dict(mode='whole')) - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test slide mode - cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test 1 decode head, 1 aux head - cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), - auxiliary_head=dict(type='ExampleDecodeHead')) - cfg.test_cfg = ConfigDict(mode='whole') - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test 1 decode head, 2 aux head - cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), - auxiliary_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - -def test_cascade_encoder_decoder(): - - # test 1 decode head, w.o. aux head - cfg = ConfigDict( - type='CascadeEncoderDecoder', - num_stages=2, - backbone=dict(type='ExampleBackbone'), - decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test slide mode - cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test 1 decode head, 1 aux head - cfg = ConfigDict( - type='CascadeEncoderDecoder', - num_stages=2, - backbone=dict(type='ExampleBackbone'), - decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') - ], - auxiliary_head=dict(type='ExampleDecodeHead')) - cfg.test_cfg = ConfigDict(mode='whole') - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) - - # test 1 decode head, 2 aux head - cfg = ConfigDict( - type='CascadeEncoderDecoder', - num_stages=2, - backbone=dict(type='ExampleBackbone'), - decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') - ], - auxiliary_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') - segmentor = build_segmentor(cfg) - _segmentor_forward_train_test(segmentor) diff --git a/tests/test_utils/test_make_divisible.py b/tests/test_utils/test_make_divisible.py deleted file mode 100644 index 5e9d1062f..000000000 --- a/tests/test_utils/test_make_divisible.py +++ /dev/null @@ -1,13 +0,0 @@ -from mmseg.models.utils import make_divisible - - -def test_make_divisible(): - # test with min_value = None - assert make_divisible(10, 4) == 12 - assert make_divisible(9, 4) == 12 - assert make_divisible(1, 4) == 4 - - # test with min_value = 8 - assert make_divisible(10, 4, 8) == 12 - assert make_divisible(9, 4, 8) == 12 - assert make_divisible(1, 4, 8) == 8 diff --git a/tests/test_utils/test_se_layer.py b/tests/test_utils/test_se_layer.py deleted file mode 100644 index 8bba7b33b..000000000 --- a/tests/test_utils/test_se_layer.py +++ /dev/null @@ -1,41 +0,0 @@ -import mmcv -import pytest -import torch - -from mmseg.models.utils.se_layer import SELayer - - -def test_se_layer(): - with pytest.raises(AssertionError): - # test act_cfg assertion. - SELayer(32, act_cfg=(dict(type='ReLU'), )) - - # test config with channels = 16. - se_layer = SELayer(16) - assert se_layer.conv1.conv.kernel_size == (1, 1) - assert se_layer.conv1.conv.stride == (1, 1) - assert se_layer.conv1.conv.padding == (0, 0) - assert isinstance(se_layer.conv1.activate, torch.nn.ReLU) - assert se_layer.conv2.conv.kernel_size == (1, 1) - assert se_layer.conv2.conv.stride == (1, 1) - assert se_layer.conv2.conv.padding == (0, 0) - assert isinstance(se_layer.conv2.activate, mmcv.cnn.HSigmoid) - - x = torch.rand(1, 16, 64, 64) - output = se_layer(x) - assert output.shape == (1, 16, 64, 64) - - # test config with channels = 16, act_cfg = dict(type='ReLU'). - se_layer = SELayer(16, act_cfg=dict(type='ReLU')) - assert se_layer.conv1.conv.kernel_size == (1, 1) - assert se_layer.conv1.conv.stride == (1, 1) - assert se_layer.conv1.conv.padding == (0, 0) - assert isinstance(se_layer.conv1.activate, torch.nn.ReLU) - assert se_layer.conv2.conv.kernel_size == (1, 1) - assert se_layer.conv2.conv.stride == (1, 1) - assert se_layer.conv2.conv.padding == (0, 0) - assert isinstance(se_layer.conv2.activate, torch.nn.ReLU) - - x = torch.rand(1, 16, 64, 64) - output = se_layer(x) - assert output.shape == (1, 16, 64, 64)