diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index fb9df6383..d69bc5bed 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -227,6 +227,13 @@ class FastSCNN(nn.Module): """ super(FastSCNN, self).__init__() + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same with Lower Input Channels!') + if scale_factor != 4: + raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!') + self.in_channels = in_channels self.downsample_dw_channels1 = downsample_dw_channels1 self.downsample_dw_channels2 = downsample_dw_channels2 diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 00ae43d00..9ba33d283 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack from mmcv.utils.parrots_wrapper import _BatchNorm from torch.nn.modules import AvgPool2d, GroupNorm -from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt +from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN from mmseg.models.backbones.resnet import BasicBlock, Bottleneck from mmseg.models.backbones.resnext import Bottleneck as BottleneckX from mmseg.models.utils import ResLayer @@ -664,3 +664,26 @@ def test_resnext_backbone(): assert feat[1].shape == torch.Size([1, 512, 28, 28]) assert feat[2].shape == torch.Size([1, 1024, 14, 14]) assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + +def test_fastscnn_backbone(): + with pytest.raises(AssertionError): + # Fast-SCNN channel constraints. + FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128) + + # Test FastSCNN Standard Forward + model = FastSCNN() + model.init_weights() + model.train() + imgs = torch.randn(1, 3, 1024, 2048) + feat = model(imgs) + + assert len(feat) == 3 + assert feat[0].shape == torch.Size([1, 64, 128, 256]) # higher-res + assert feat[1].shape == torch.Size([1, 128, 32, 64]) # lower-res + assert feat[2].shape == torch.Size([1, 128, 128, 256]) # FFM output + + + + +