mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Fix fastscnn resize problems. (#82)
* Fix fast_scnn resize problems * Fix fast_scnn resize problems 1 * Fix fast_scnn resize problems 2 * test for pascal voc
This commit is contained in:
parent
11dd9859c2
commit
65dae41bbf
70
configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py
Normal file
70
configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/fast_scnn.py', '../_base_/datasets/pascal_voc12.py',
|
||||||
|
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Re-config the data loader: samples and workers per GPU.
|
||||||
|
data = dict(samples_per_gpu=8, workers_per_gpu=4)
|
||||||
|
|
||||||
|
# Re-config the optimizer.
|
||||||
|
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
|
||||||
|
|
||||||
|
# Update num_classes of the segmentor to 21 (decode head and both auxiliary heads).
|
||||||
|
# model settings
|
||||||
|
norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
|
||||||
|
model = dict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
backbone=dict(
|
||||||
|
type='FastSCNN',
|
||||||
|
downsample_dw_channels=(32, 48),
|
||||||
|
global_in_channels=64,
|
||||||
|
global_block_channels=(64, 96, 128),
|
||||||
|
global_block_strides=(2, 2, 1),
|
||||||
|
global_out_channels=128,
|
||||||
|
higher_in_channels=64,
|
||||||
|
lower_in_channels=128,
|
||||||
|
fusion_out_channels=128,
|
||||||
|
out_indices=(0, 1, 2),
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False),
|
||||||
|
decode_head=dict(
|
||||||
|
type='DepthwiseSeparableFCNHead',
|
||||||
|
in_channels=128,
|
||||||
|
channels=128,
|
||||||
|
concat_input=False,
|
||||||
|
num_classes=21,
|
||||||
|
in_index=-1,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
|
||||||
|
auxiliary_head=[
|
||||||
|
dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=128,
|
||||||
|
channels=32,
|
||||||
|
num_convs=1,
|
||||||
|
num_classes=21,
|
||||||
|
in_index=-2,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
concat_input=False,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=64,
|
||||||
|
channels=32,
|
||||||
|
num_convs=1,
|
||||||
|
num_classes=21,
|
||||||
|
in_index=-3,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
concat_input=False,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
])
|
||||||
|
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg = dict()
|
||||||
|
test_cfg = dict(mode='whole')
|
@ -3,7 +3,7 @@ import mmcv
|
|||||||
from .version import __version__, version_info
|
from .version import __version__, version_info
|
||||||
|
|
||||||
MMCV_MIN = '1.0.5'
|
MMCV_MIN = '1.0.5'
|
||||||
MMCV_MAX = '1.0.5'
|
MMCV_MAX = '1.1.0'
|
||||||
|
|
||||||
|
|
||||||
def digit_version(version_str):
|
def digit_version(version_str):
|
||||||
|
@ -186,9 +186,6 @@ class FeatureFusionModule(nn.Module):
|
|||||||
lower_in_channels (int): Number of input channels of the
|
lower_in_channels (int): Number of input channels of the
|
||||||
lower-resolution branch.
|
lower-resolution branch.
|
||||||
out_channels (int): Number of output channels.
|
out_channels (int): Number of output channels.
|
||||||
scale_factor (int): Scale factor applied to the lower-res input.
|
|
||||||
Should be coherent with the downsampling factor determined
|
|
||||||
by the GFE module.
|
|
||||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||||
norm_cfg (dict | None): Config of norm layers. Default:
|
norm_cfg (dict | None): Config of norm layers. Default:
|
||||||
dict(type='BN')
|
dict(type='BN')
|
||||||
@ -202,13 +199,11 @@ class FeatureFusionModule(nn.Module):
|
|||||||
higher_in_channels,
|
higher_in_channels,
|
||||||
lower_in_channels,
|
lower_in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
scale_factor,
|
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
align_corners=False):
|
align_corners=False):
|
||||||
super(FeatureFusionModule, self).__init__()
|
super(FeatureFusionModule, self).__init__()
|
||||||
self.scale_factor = scale_factor
|
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
self.act_cfg = act_cfg
|
self.act_cfg = act_cfg
|
||||||
@ -239,7 +234,7 @@ class FeatureFusionModule(nn.Module):
|
|||||||
def forward(self, higher_res_feature, lower_res_feature):
|
def forward(self, higher_res_feature, lower_res_feature):
|
||||||
lower_res_feature = resize(
|
lower_res_feature = resize(
|
||||||
lower_res_feature,
|
lower_res_feature,
|
||||||
scale_factor=self.scale_factor,
|
size=higher_res_feature.size()[2:],
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=self.align_corners)
|
align_corners=self.align_corners)
|
||||||
lower_res_feature = self.dwconv(lower_res_feature)
|
lower_res_feature = self.dwconv(lower_res_feature)
|
||||||
@ -321,11 +316,6 @@ class FastSCNN(nn.Module):
|
|||||||
raise AssertionError('Global Output Channels must be the same \
|
raise AssertionError('Global Output Channels must be the same \
|
||||||
with Lower Input Channels!')
|
with Lower Input Channels!')
|
||||||
|
|
||||||
# Calculate scale factor used in FFM.
|
|
||||||
self.scale_factor = 1
|
|
||||||
for factor in global_block_strides:
|
|
||||||
self.scale_factor *= factor
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||||
@ -361,7 +351,6 @@ class FastSCNN(nn.Module):
|
|||||||
higher_in_channels,
|
higher_in_channels,
|
||||||
lower_in_channels,
|
lower_in_channels,
|
||||||
fusion_out_channels,
|
fusion_out_channels,
|
||||||
scale_factor=self.scale_factor,
|
|
||||||
conv_cfg=self.conv_cfg,
|
conv_cfg=self.conv_cfg,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
act_cfg=self.act_cfg,
|
act_cfg=self.act_cfg,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user