From 88a123d16ffe2e3f8a9f34a6c5b107f1e892bac0 Mon Sep 17 00:00:00 2001 From: johnzja Date: Fri, 14 Aug 2020 12:09:22 +0800 Subject: [PATCH] Arg scale_factor deleted. --- configs/_base_/models/fast_scnn.py | 2 +- mmseg/models/backbones/fast_scnn.py | 36 +++++++++++++++++++---------- tests/test_models/test_backbone.py | 9 +++++--- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/configs/_base_/models/fast_scnn.py b/configs/_base_/models/fast_scnn.py index 94d6ab93e..0aefda274 100644 --- a/configs/_base_/models/fast_scnn.py +++ b/configs/_base_/models/fast_scnn.py @@ -7,11 +7,11 @@ model = dict( downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), + global_block_downsample_factors=(2, 2, 1), global_out_channels=128, higher_in_channels=64, lower_in_channels=128, fusion_out_channels=128, - scale_factor=4, out_indices=(0, 1, 2), norm_cfg=norm_cfg, align_corners=False), diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index c5322dfb2..c3531baa4 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -83,7 +83,11 @@ class GlobalFeatureExtractor(nn.Module): module. Default: 6 num_blocks (tuple[int]): Tuple of ints. Each int specifies the number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. Default: (3, 3, 3) + downsample_factors (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) pool_scales (tuple[int]): Tuple of ints. Each int specifies the parameter required in 'global average pooling' within PPM. Default: (1, 2, 3, 6) @@ -102,6 +106,7 @@ class GlobalFeatureExtractor(nn.Module): out_channels=128, expand_ratio=6, num_blocks=(3, 3, 3), + downsample_factors=(2, 2, 1), pool_scales=(1, 2, 3, 6), conv_cfg=None, norm_cfg=dict(type='BN'), @@ -113,13 +118,17 @@ class GlobalFeatureExtractor(nn.Module): self.act_cfg = act_cfg assert len(block_channels) == len(num_blocks) == 3 self.bottleneck1 = self._make_layer(in_channels, block_channels[0], - num_blocks[0], 2, expand_ratio) + num_blocks[0], + downsample_factors[0], + expand_ratio) self.bottleneck2 = self._make_layer(block_channels[0], block_channels[1], num_blocks[1], - 2, expand_ratio) + downsample_factors[1], + expand_ratio) self.bottleneck3 = self._make_layer(block_channels[1], block_channels[2], num_blocks[2], - 1, expand_ratio) + downsample_factors[2], + expand_ratio) self.ppm = PPM( pool_scales, block_channels[2], @@ -260,6 +269,10 @@ class FastSCNN(nn.Module): the output channels for each of the MobileNet-v2 bottleneck residual blocks in GFE. Default: (64, 96, 128). + global_block_downsample_factors (tuple[int]): Tuple of integers + that describe the downsampling factors for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). global_out_channels (int): Number of output channels of GFE. Default: 128. higher_in_channels (int): Number of input channels of the higher @@ -272,10 +285,6 @@ class FastSCNN(nn.Module): Default: 128. fusion_out_channels (int): Number of output channels of FFM. Default: 128. - scale_factor (int): The upsampling factor of the higher resolution - branch in FFM. - Equal to the downsampling factor in GFE. - Default: 4. out_indices (tuple): Tuple of indices of list [higher_res_features, lower_res_features, fusion_output]. Often set to (0,1,2) to enable aux. heads. @@ -294,11 +303,11 @@ class FastSCNN(nn.Module): downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), + global_block_downsample_factors=(2, 2, 1), global_out_channels=128, higher_in_channels=64, lower_in_channels=128, fusion_out_channels=128, - scale_factor=4, out_indices=(0, 1, 2), conv_cfg=None, norm_cfg=dict(type='BN'), @@ -312,20 +321,22 @@ class FastSCNN(nn.Module): elif global_out_channels != lower_in_channels: raise AssertionError('Global Output Channels must be the same \ with Lower Input Channels!') - if scale_factor != 4: - raise AssertionError('Scale-factor must compensate the \ - downsampling factor in the GFE module!') + + # Calculate scale factor used in FFM. + self.scale_factor = 1 + for factor in global_block_downsample_factors: + self.scale_factor *= factor self.in_channels = in_channels self.downsample_dw_channels1 = downsample_dw_channels[0] self.downsample_dw_channels2 = downsample_dw_channels[1] self.global_in_channels = global_in_channels self.global_block_channels = global_block_channels + self.global_block_downsample_factors = global_block_downsample_factors self.global_out_channels = global_out_channels self.higher_in_channels = higher_in_channels self.lower_in_channels = lower_in_channels self.fusion_out_channels = fusion_out_channels - self.scale_factor = scale_factor self.out_indices = out_indices self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -342,6 +353,7 @@ class FastSCNN(nn.Module): global_in_channels, global_block_channels, global_out_channels, + downsample_factors=self.global_block_downsample_factors, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 767981bc5..65d814d38 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -47,7 +47,6 @@ def check_norm_state(modules, train_state): def test_resnet_basic_block(): - with pytest.raises(AssertionError): # Not implemented yet. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) @@ -97,7 +96,6 @@ def test_resnet_basic_block(): def test_resnet_bottleneck(): - with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] Bottleneck(64, 64, style='tensorflow') @@ -669,7 +667,12 @@ def test_resnext_backbone(): def test_fastscnn_backbone(): with pytest.raises(AssertionError): # Fast-SCNN channel constraints. - FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128) + FastSCNN( + 3, (32, 48), + 64, (64, 96, 128), (2, 2, 1), + global_out_channels=127, + higher_in_channels=64, + lower_in_channels=128) # Test FastSCNN Standard Forward model = FastSCNN()