diff --git a/configs/_base_/models/fast_scnn.py b/configs/_base_/models/fast_scnn.py
index 0aefda274..67ee0d39a 100644
--- a/configs/_base_/models/fast_scnn.py
+++ b/configs/_base_/models/fast_scnn.py
@@ -7,7 +7,7 @@ model = dict(
         downsample_dw_channels=(32, 48),
         global_in_channels=64,
         global_block_channels=(64, 96, 128),
-        global_block_downsample_factors=(2, 2, 1),
+        global_block_strides=(2, 2, 1),
         global_out_channels=128,
         higher_in_channels=64,
         lower_in_channels=128,
diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py
index c3531baa4..f1c465771 100644
--- a/mmseg/models/backbones/fast_scnn.py
+++ b/mmseg/models/backbones/fast_scnn.py
@@ -79,13 +79,14 @@ class GlobalFeatureExtractor(nn.Module):
             Default: (64, 96, 128)
         out_channels(int): Number of output channels of the GFE module.
             Default: 128
-        expand_ratio (int): Upsampling factor of each Inverted Residual
-            module. Default: 6
+        expand_ratio (int): Adjusts number of channels of the hidden layer
+            in InvertedResidual by this amount.
+            Default: 6
         num_blocks (tuple[int]): Tuple of ints. Each int specifies the
             number of times each Inverted Residual module is repeated.
             The repeated Inverted Residual modules are called a 'group'.
             Default: (3, 3, 3)
-        downsample_factors (tuple[int]): Tuple of ints. Each int specifies
+        strides (tuple[int]): Tuple of ints. Each int specifies
             the downsampling factor of each 'group'.
             Default: (2, 2, 1)
         pool_scales (tuple[int]): Tuple of ints. Each int specifies
@@ -106,7 +107,7 @@ class GlobalFeatureExtractor(nn.Module):
                  out_channels=128,
                  expand_ratio=6,
                  num_blocks=(3, 3, 3),
-                 downsample_factors=(2, 2, 1),
+                 strides=(2, 2, 1),
                  pool_scales=(1, 2, 3, 6),
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
@@ -118,17 +119,14 @@ class GlobalFeatureExtractor(nn.Module):
         self.act_cfg = act_cfg
         assert len(block_channels) == len(num_blocks) == 3
         self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
-                                            num_blocks[0],
-                                            downsample_factors[0],
+                                            num_blocks[0], strides[0],
                                             expand_ratio)
         self.bottleneck2 = self._make_layer(block_channels[0],
                                             block_channels[1], num_blocks[1],
-                                            downsample_factors[1],
-                                            expand_ratio)
+                                            strides[1], expand_ratio)
         self.bottleneck3 = self._make_layer(block_channels[1],
                                             block_channels[2], num_blocks[2],
-                                            downsample_factors[2],
-                                            expand_ratio)
+                                            strides[2], expand_ratio)
         self.ppm = PPM(
             pool_scales,
             block_channels[2],
@@ -269,8 +267,8 @@ class FastSCNN(nn.Module):
             the output channels for each of the MobileNet-v2 bottleneck
             residual blocks in GFE.
             Default: (64, 96, 128).
-        global_block_downsample_factors (tuple[int]): Tuple of integers
-            that describe the downsampling factors for each of the
+        global_block_strides (tuple[int]): Tuple of integers
+            that describe the strides (downsampling factors) for each of the
             MobileNet-v2 bottleneck residual blocks in GFE.
             Default: (2, 2, 1).
         global_out_channels (int): Number of output channels of GFE.
@@ -303,7 +301,7 @@ class FastSCNN(nn.Module):
                  downsample_dw_channels=(32, 48),
                  global_in_channels=64,
                  global_block_channels=(64, 96, 128),
-                 global_block_downsample_factors=(2, 2, 1),
+                 global_block_strides=(2, 2, 1),
                  global_out_channels=128,
                  higher_in_channels=64,
                  lower_in_channels=128,
@@ -324,7 +322,7 @@ class FastSCNN(nn.Module):
 
         # Calculate scale factor used in FFM.
         self.scale_factor = 1
-        for factor in global_block_downsample_factors:
+        for factor in global_block_strides:
             self.scale_factor *= factor
 
         self.in_channels = in_channels
@@ -332,7 +330,7 @@ class FastSCNN(nn.Module):
         self.downsample_dw_channels2 = downsample_dw_channels[1]
         self.global_in_channels = global_in_channels
         self.global_block_channels = global_block_channels
-        self.global_block_downsample_factors = global_block_downsample_factors
+        self.global_block_strides = global_block_strides
         self.global_out_channels = global_out_channels
         self.higher_in_channels = higher_in_channels
         self.lower_in_channels = lower_in_channels
@@ -353,7 +351,7 @@ class FastSCNN(nn.Module):
             global_in_channels,
             global_block_channels,
             global_out_channels,
-            downsample_factors=self.global_block_downsample_factors,
+            strides=self.global_block_strides,
             conv_cfg=self.conv_cfg,
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg,