From e35f9acde1e4a259c00d3e9947b51f9773313287 Mon Sep 17 00:00:00 2001 From: johnzja Date: Sun, 9 Aug 2020 22:48:59 +0800 Subject: [PATCH] Code style improved. --- mmseg/models/backbones/fast_scnn.py | 39 ++++++++++++++--------- mmseg/models/decode_heads/sep_fcn_head.py | 9 ++++-- tests/test_models/test_backbone.py | 16 ++++------ 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index d69bc5bed..522411ac7 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -10,7 +10,7 @@ from ..builder import BACKBONES class LearningToDownsample(nn.Module): - """Learning to downsample module""" + """Learning to downsample module.""" def __init__(self, in_channels, @@ -53,7 +53,7 @@ class LearningToDownsample(nn.Module): class GlobalFeatureExtractor(nn.Module): - """Global feature extractor module""" + """Global feature extractor module.""" def __init__(self, in_channels=64, @@ -115,7 +115,7 @@ class GlobalFeatureExtractor(nn.Module): class FeatureFusionModule(nn.Module): - """Feature fusion module""" + """Feature fusion module.""" def __init__(self, higher_in_channels, @@ -188,7 +188,6 @@ class FastSCNN(nn.Module): norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False): - """Fast-SCNN Backbone. Args: in_channels(int): Number of input image channels. Default=3 (RGB) @@ -196,28 +195,35 @@ class FastSCNN(nn.Module): downsample_dw_channels1(int): Number of output channels after the first conv layer in Learning-To-Downsample (LTD) module. - downsample_dw_channels2(int): Number of output channels after the second conv layer in LTD. + downsample_dw_channels2(int): Number of output channels + after the second conv layer in LTD. - global_in_channels(int): Number of input channels of Global Feature Extractor(GFE). + global_in_channels(int): Number of input channels of + Global Feature Extractor(GFE). Equal to number of output channels of LTD. - global_block_channels(tuple): Tuple of integers that describe the output channels for - each of the MobileNet-v2 bottleneck residual blocks in GFE. + global_block_channels(tuple): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. global_out_channels(int): Number of output channels of GFE. - higher_in_channels(int): Number of input channels of the higher resolution branch in FFM. + higher_in_channels(int): Number of input channels of the higher + resolution branch in FFM. Equal to global_in_channels. - lower_in_channels(int): Number of input channels of the lower resolution branch in FFM. + lower_in_channels(int): Number of input channels of the lower + resolution branch in FFM. Equal to global_out_channels. fusion_out_channels(int): Number of output channels of FFM. - scale_factor(int): The upsampling factor of the higher resolution branch in FFM. + scale_factor(int): The upsampling factor of the higher resolution + branch in FFM. Equal to the downsampling factor in GFE. - out_indices(tuple): Tuple of indices of list [higher_res_features, lower_res_features, fusion_output]. + out_indices(tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. Often set to (0,1,2) to enable aux. heads. conv_cfg (dict|None): Config of conv layers. @@ -228,11 +234,14 @@ class FastSCNN(nn.Module): super(FastSCNN, self).__init__() if global_in_channels != higher_in_channels: - raise AssertionError('Global Input Channels must be the same with Higher Input Channels!') + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') elif global_out_channels != lower_in_channels: - raise AssertionError('Global Output Channels must be the same with Lower Input Channels!') + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') if scale_factor != 4: - raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!') + raise AssertionError('Scale-factor must compensate the \ + downsampling factor in the GFE module!') self.in_channels = in_channels self.downsample_dw_channels1 = downsample_dw_channels1 diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index d93c246f7..c7030b2cb 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -5,7 +5,8 @@ from .fcn_head import FCNHead @HEADS.register_module() class SepFCNHead(FCNHead): - """Depthwise-Separable Fully Convolutional Network for Semantic Segmentation + """Depthwise-Separable Fully Convolutional Network for Semantic + Segmentation. This head is implemented according to Fast-SCNN. Args: @@ -16,7 +17,8 @@ class SepFCNHead(FCNHead): concat_input(bool): Whether to concatenate original decode input into the result of consecutive convolution layers. - num_classes(int): Used to determine the dimension of final prediction tensor. + num_classes(int): Used to determine the dimension of + final prediction tensor. in_index(int): Correspond with 'out_indices' in FastSCNN backbone. @@ -24,7 +26,8 @@ class SepFCNHead(FCNHead): align_corners (bool): align_corners argument of F.interpolate. - loss_decode(dict): Config of loss type and some relevant additional options. + loss_decode(dict): Config of loss type and some + relevant additional options. """ def __init__(self, **kwargs): diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 282550179..c030464f0 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack from mmcv.utils.parrots_wrapper import _BatchNorm from torch.nn.modules import AvgPool2d, GroupNorm -from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN +from mmseg.models.backbones import FastSCNN, ResNet, ResNetV1d, ResNeXt from mmseg.models.backbones.resnet import BasicBlock, Bottleneck from mmseg.models.backbones.resnext import Bottleneck as BottleneckX from mmseg.models.utils import ResLayer @@ -680,11 +680,9 @@ def test_fastscnn_backbone(): feat = model(imgs) assert len(feat) == 3 - assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) # higher-res - assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) # lower-res - assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256]) # FFM output - - - - - + # higher-res + assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) + # lower-res + assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) + # FFM output + assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256])