Code style improved.

pull/58/head
johnzja 2020-08-09 22:48:59 +08:00
parent daf93c6355
commit e35f9acde1
3 changed files with 37 additions and 27 deletions

View File

@ -10,7 +10,7 @@ from ..builder import BACKBONES
class LearningToDownsample(nn.Module):
"""Learning to downsample module"""
"""Learning to downsample module."""
def __init__(self,
in_channels,
@ -53,7 +53,7 @@ class LearningToDownsample(nn.Module):
class GlobalFeatureExtractor(nn.Module):
"""Global feature extractor module"""
"""Global feature extractor module."""
def __init__(self,
in_channels=64,
@ -115,7 +115,7 @@ class GlobalFeatureExtractor(nn.Module):
class FeatureFusionModule(nn.Module):
"""Feature fusion module"""
"""Feature fusion module."""
def __init__(self,
higher_in_channels,
@ -188,7 +188,6 @@ class FastSCNN(nn.Module):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
"""Fast-SCNN Backbone.
Args:
in_channels(int): Number of input image channels. Default=3 (RGB)
@ -196,28 +195,35 @@ class FastSCNN(nn.Module):
downsample_dw_channels1(int): Number of output channels after
the first conv layer in Learning-To-Downsample (LTD) module.
downsample_dw_channels2(int): Number of output channels after the second conv layer in LTD.
downsample_dw_channels2(int): Number of output channels
after the second conv layer in LTD.
global_in_channels(int): Number of input channels of Global Feature Extractor(GFE).
global_in_channels(int): Number of input channels of
Global Feature Extractor(GFE).
Equal to number of output channels of LTD.
global_block_channels(tuple): Tuple of integers that describe the output channels for
each of the MobileNet-v2 bottleneck residual blocks in GFE.
global_block_channels(tuple): Tuple of integers that describe
the output channels for each of the MobileNet-v2 bottleneck
residual blocks in GFE.
global_out_channels(int): Number of output channels of GFE.
higher_in_channels(int): Number of input channels of the higher resolution branch in FFM.
higher_in_channels(int): Number of input channels of the higher
resolution branch in FFM.
Equal to global_in_channels.
lower_in_channels(int): Number of input channels of the lower resolution branch in FFM.
lower_in_channels(int): Number of input channels of the lower
resolution branch in FFM.
Equal to global_out_channels.
fusion_out_channels(int): Number of output channels of FFM.
scale_factor(int): The upsampling factor of the higher resolution branch in FFM.
scale_factor(int): The upsampling factor of the higher resolution
branch in FFM.
Equal to the downsampling factor in GFE.
out_indices(tuple): Tuple of indices of list [higher_res_features, lower_res_features, fusion_output].
out_indices(tuple): Tuple of indices of list
[higher_res_features, lower_res_features, fusion_output].
Often set to (0,1,2) to enable aux. heads.
conv_cfg (dict|None): Config of conv layers.
@ -228,11 +234,14 @@ class FastSCNN(nn.Module):
super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same with Higher Input Channels!')
raise AssertionError('Global Input Channels must be the same \
with Higher Input Channels!')
elif global_out_channels != lower_in_channels:
raise AssertionError('Global Output Channels must be the same with Lower Input Channels!')
raise AssertionError('Global Output Channels must be the same \
with Lower Input Channels!')
if scale_factor != 4:
raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!')
raise AssertionError('Scale-factor must compensate the \
downsampling factor in the GFE module!')
self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels1

View File

@ -5,7 +5,8 @@ from .fcn_head import FCNHead
@HEADS.register_module()
class SepFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic Segmentation
"""Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation.
This head is implemented according to Fast-SCNN.
Args:
@ -16,7 +17,8 @@ class SepFCNHead(FCNHead):
concat_input(bool): Whether to concatenate original decode input into
the result of consecutive convolution layers.
num_classes(int): Used to determine the dimension of final prediction tensor.
num_classes(int): Used to determine the dimension of
final prediction tensor.
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
@ -24,7 +26,8 @@ class SepFCNHead(FCNHead):
align_corners (bool): align_corners argument of F.interpolate.
loss_decode(dict): Config of loss type and some relevant additional options.
loss_decode(dict): Config of loss type and some
relevant additional options.
"""
def __init__(self, **kwargs):

View File

@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch.nn.modules import AvgPool2d, GroupNorm
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN
from mmseg.models.backbones import FastSCNN, ResNet, ResNetV1d, ResNeXt
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
from mmseg.models.utils import ResLayer
@ -680,11 +680,9 @@ def test_fastscnn_backbone():
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) # higher-res
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) # lower-res
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256]) # FFM output
# higher-res
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256])
# lower-res
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64])
# FFM output
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256])