diff --git a/configs/_base_/models/fast_scnn.py b/configs/_base_/models/fast_scnn.py index 5be999f00..94d6ab93e 100644 --- a/configs/_base_/models/fast_scnn.py +++ b/configs/_base_/models/fast_scnn.py @@ -4,8 +4,7 @@ model = dict( type='EncoderDecoder', backbone=dict( type='FastSCNN', - downsample_dw_channels1=32, - downsample_dw_channels2=48, + downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), global_out_channels=128, diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index d6eb23e48..dabbf6c3b 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -10,12 +10,11 @@ from ..builder import BACKBONES class LearningToDownsample(nn.Module): """Learning to downsample module. + Args: in_channels (int): Number of input channels. - dw_channels1 (int): Number of output channels of the first - depthwise conv (dwconv) layer. - dw_channels2 (int): Number of output channels of the second - dwconv layer. + dw_channels (tuple): Number of output channels of the first and + the second depthwise conv (dwconv) layers. out_channels (int): Number of output channels of the whole 'learning to downsample' module. conv_cfg (dict | None): Config of conv layers. Default: None @@ -27,8 +26,7 @@ class LearningToDownsample(nn.Module): def __init__(self, in_channels, - dw_channels1, - dw_channels2, + dw_channels, out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), @@ -37,6 +35,9 @@ class LearningToDownsample(nn.Module): self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + self.conv = ConvModule( in_channels, dw_channels1, @@ -235,12 +236,10 @@ class FastSCNN(nn.Module): """Fast-SCNN Backbone. Args: in_channels (int): Number of input image channels. Default: 3. - downsample_dw_channels1 (int): Number of output channels after - the first conv layer in Learning-To-Downsample (LTD) module. - Default: 32. - downsample_dw_channels2 (int): Number of output channels - after the second conv layer in LTD. - Default: 48. + downsample_dw_channels (tuple): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). global_in_channels (int): Number of input channels of Global Feature Extractor(GFE). Equal to number of output channels of LTD. @@ -280,8 +279,7 @@ class FastSCNN(nn.Module): def __init__(self, in_channels=3, - downsample_dw_channels1=32, - downsample_dw_channels2=48, + downsample_dw_channels=(32, 48), global_in_channels=64, global_block_channels=(64, 96, 128), global_out_channels=128, @@ -307,8 +305,8 @@ class FastSCNN(nn.Module): downsampling factor in the GFE module!') self.in_channels = in_channels - self.downsample_dw_channels1 = downsample_dw_channels1 - self.downsample_dw_channels2 = downsample_dw_channels2 + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] self.global_in_channels = global_in_channels self.global_block_channels = global_block_channels self.global_out_channels = global_out_channels @@ -323,8 +321,7 @@ class FastSCNN(nn.Module): self.align_corners = align_corners self.learning_to_downsample = LearningToDownsample( in_channels, - downsample_dw_channels1, - downsample_dw_channels2, + downsample_dw_channels, global_in_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 084849aa4..767981bc5 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -669,7 +669,7 @@ def test_resnext_backbone(): def test_fastscnn_backbone(): with pytest.raises(AssertionError): # Fast-SCNN channel constraints. - FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128) + FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128) # Test FastSCNN Standard Forward model = FastSCNN()