Code style improved.
parent
daf93c6355
commit
e35f9acde1
|
@ -10,7 +10,7 @@ from ..builder import BACKBONES
|
||||||
|
|
||||||
|
|
||||||
class LearningToDownsample(nn.Module):
|
class LearningToDownsample(nn.Module):
|
||||||
"""Learning to downsample module"""
|
"""Learning to downsample module."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -53,7 +53,7 @@ class LearningToDownsample(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class GlobalFeatureExtractor(nn.Module):
|
class GlobalFeatureExtractor(nn.Module):
|
||||||
"""Global feature extractor module"""
|
"""Global feature extractor module."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels=64,
|
in_channels=64,
|
||||||
|
@ -115,7 +115,7 @@ class GlobalFeatureExtractor(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FeatureFusionModule(nn.Module):
|
class FeatureFusionModule(nn.Module):
|
||||||
"""Feature fusion module"""
|
"""Feature fusion module."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
higher_in_channels,
|
higher_in_channels,
|
||||||
|
@ -188,7 +188,6 @@ class FastSCNN(nn.Module):
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
align_corners=False):
|
align_corners=False):
|
||||||
|
|
||||||
"""Fast-SCNN Backbone.
|
"""Fast-SCNN Backbone.
|
||||||
Args:
|
Args:
|
||||||
in_channels(int): Number of input image channels. Default=3 (RGB)
|
in_channels(int): Number of input image channels. Default=3 (RGB)
|
||||||
|
@ -196,28 +195,35 @@ class FastSCNN(nn.Module):
|
||||||
downsample_dw_channels1(int): Number of output channels after
|
downsample_dw_channels1(int): Number of output channels after
|
||||||
the first conv layer in Learning-To-Downsample (LTD) module.
|
the first conv layer in Learning-To-Downsample (LTD) module.
|
||||||
|
|
||||||
downsample_dw_channels2(int): Number of output channels after the second conv layer in LTD.
|
downsample_dw_channels2(int): Number of output channels
|
||||||
|
after the second conv layer in LTD.
|
||||||
|
|
||||||
global_in_channels(int): Number of input channels of Global Feature Extractor(GFE).
|
global_in_channels(int): Number of input channels of
|
||||||
|
Global Feature Extractor(GFE).
|
||||||
Equal to number of output channels of LTD.
|
Equal to number of output channels of LTD.
|
||||||
|
|
||||||
global_block_channels(tuple): Tuple of integers that describe the output channels for
|
global_block_channels(tuple): Tuple of integers that describe
|
||||||
each of the MobileNet-v2 bottleneck residual blocks in GFE.
|
the output channels for each of the MobileNet-v2 bottleneck
|
||||||
|
residual blocks in GFE.
|
||||||
|
|
||||||
global_out_channels(int): Number of output channels of GFE.
|
global_out_channels(int): Number of output channels of GFE.
|
||||||
|
|
||||||
higher_in_channels(int): Number of input channels of the higher resolution branch in FFM.
|
higher_in_channels(int): Number of input channels of the higher
|
||||||
|
resolution branch in FFM.
|
||||||
Equal to global_in_channels.
|
Equal to global_in_channels.
|
||||||
|
|
||||||
lower_in_channels(int): Number of input channels of the lower resolution branch in FFM.
|
lower_in_channels(int): Number of input channels of the lower
|
||||||
|
resolution branch in FFM.
|
||||||
Equal to global_out_channels.
|
Equal to global_out_channels.
|
||||||
|
|
||||||
fusion_out_channels(int): Number of output channels of FFM.
|
fusion_out_channels(int): Number of output channels of FFM.
|
||||||
|
|
||||||
scale_factor(int): The upsampling factor of the higher resolution branch in FFM.
|
scale_factor(int): The upsampling factor of the higher resolution
|
||||||
|
branch in FFM.
|
||||||
Equal to the downsampling factor in GFE.
|
Equal to the downsampling factor in GFE.
|
||||||
|
|
||||||
out_indices(tuple): Tuple of indices of list [higher_res_features, lower_res_features, fusion_output].
|
out_indices(tuple): Tuple of indices of list
|
||||||
|
[higher_res_features, lower_res_features, fusion_output].
|
||||||
Often set to (0,1,2) to enable aux. heads.
|
Often set to (0,1,2) to enable aux. heads.
|
||||||
|
|
||||||
conv_cfg (dict|None): Config of conv layers.
|
conv_cfg (dict|None): Config of conv layers.
|
||||||
|
@ -228,11 +234,14 @@ class FastSCNN(nn.Module):
|
||||||
|
|
||||||
super(FastSCNN, self).__init__()
|
super(FastSCNN, self).__init__()
|
||||||
if global_in_channels != higher_in_channels:
|
if global_in_channels != higher_in_channels:
|
||||||
raise AssertionError('Global Input Channels must be the same with Higher Input Channels!')
|
raise AssertionError('Global Input Channels must be the same \
|
||||||
|
with Higher Input Channels!')
|
||||||
elif global_out_channels != lower_in_channels:
|
elif global_out_channels != lower_in_channels:
|
||||||
raise AssertionError('Global Output Channels must be the same with Lower Input Channels!')
|
raise AssertionError('Global Output Channels must be the same \
|
||||||
|
with Lower Input Channels!')
|
||||||
if scale_factor != 4:
|
if scale_factor != 4:
|
||||||
raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!')
|
raise AssertionError('Scale-factor must compensate the \
|
||||||
|
downsampling factor in the GFE module!')
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.downsample_dw_channels1 = downsample_dw_channels1
|
self.downsample_dw_channels1 = downsample_dw_channels1
|
||||||
|
|
|
@ -5,7 +5,8 @@ from .fcn_head import FCNHead
|
||||||
|
|
||||||
@HEADS.register_module()
|
@HEADS.register_module()
|
||||||
class SepFCNHead(FCNHead):
|
class SepFCNHead(FCNHead):
|
||||||
"""Depthwise-Separable Fully Convolutional Network for Semantic Segmentation
|
"""Depthwise-Separable Fully Convolutional Network for Semantic
|
||||||
|
Segmentation.
|
||||||
|
|
||||||
This head is implemented according to Fast-SCNN.
|
This head is implemented according to Fast-SCNN.
|
||||||
Args:
|
Args:
|
||||||
|
@ -16,7 +17,8 @@ class SepFCNHead(FCNHead):
|
||||||
concat_input(bool): Whether to concatenate original decode input into
|
concat_input(bool): Whether to concatenate original decode input into
|
||||||
the result of consecutive convolution layers.
|
the result of consecutive convolution layers.
|
||||||
|
|
||||||
num_classes(int): Used to determine the dimension of final prediction tensor.
|
num_classes(int): Used to determine the dimension of
|
||||||
|
final prediction tensor.
|
||||||
|
|
||||||
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
|
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
|
||||||
|
|
||||||
|
@ -24,7 +26,8 @@ class SepFCNHead(FCNHead):
|
||||||
|
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
|
||||||
loss_decode(dict): Config of loss type and some relevant additional options.
|
loss_decode(dict): Config of loss type and some
|
||||||
|
relevant additional options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|
|
@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||||
|
|
||||||
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN
|
from mmseg.models.backbones import FastSCNN, ResNet, ResNetV1d, ResNeXt
|
||||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||||
from mmseg.models.utils import ResLayer
|
from mmseg.models.utils import ResLayer
|
||||||
|
@ -680,11 +680,9 @@ def test_fastscnn_backbone():
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
|
|
||||||
assert len(feat) == 3
|
assert len(feat) == 3
|
||||||
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) # higher-res
|
# higher-res
|
||||||
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) # lower-res
|
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256])
|
||||||
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256]) # FFM output
|
# lower-res
|
||||||
|
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64])
|
||||||
|
# FFM output
|
||||||
|
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue