FastSCNN type(dwchannels) changed to tuple.

pull/58/head
johnzja 2020-08-12 11:22:11 +08:00
parent ed3a6d0a70
commit b08e1d4e9e
3 changed files with 17 additions and 21 deletions

View File

@ -4,8 +4,7 @@ model = dict(
type='EncoderDecoder', type='EncoderDecoder',
backbone=dict( backbone=dict(
type='FastSCNN', type='FastSCNN',
downsample_dw_channels1=32, downsample_dw_channels=(32, 48),
downsample_dw_channels2=48,
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_out_channels=128, global_out_channels=128,

View File

@ -10,12 +10,11 @@ from ..builder import BACKBONES
class LearningToDownsample(nn.Module): class LearningToDownsample(nn.Module):
"""Learning to downsample module. """Learning to downsample module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int): Number of input channels.
dw_channels1 (int): Number of output channels of the first dw_channels (tuple): Number of output channels of the first and
depthwise conv (dwconv) layer. the second depthwise conv (dwconv) layers.
dw_channels2 (int): Number of output channels of the second
dwconv layer.
out_channels (int): Number of output channels of the whole out_channels (int): Number of output channels of the whole
'learning to downsample' module. 'learning to downsample' module.
conv_cfg (dict | None): Config of conv layers. Default: None conv_cfg (dict | None): Config of conv layers. Default: None
@ -27,8 +26,7 @@ class LearningToDownsample(nn.Module):
def __init__(self, def __init__(self,
in_channels, in_channels,
dw_channels1, dw_channels,
dw_channels2,
out_channels, out_channels,
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
@ -37,6 +35,9 @@ class LearningToDownsample(nn.Module):
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
dw_channels1 = dw_channels[0]
dw_channels2 = dw_channels[1]
self.conv = ConvModule( self.conv = ConvModule(
in_channels, in_channels,
dw_channels1, dw_channels1,
@ -235,12 +236,10 @@ class FastSCNN(nn.Module):
"""Fast-SCNN Backbone. """Fast-SCNN Backbone.
Args: Args:
in_channels (int): Number of input image channels. Default: 3. in_channels (int): Number of input image channels. Default: 3.
downsample_dw_channels1 (int): Number of output channels after downsample_dw_channels (tuple): Number of output channels after
the first conv layer in Learning-To-Downsample (LTD) module. the first conv layer & the second conv layer in
Default: 32. Learning-To-Downsample (LTD) module.
downsample_dw_channels2 (int): Number of output channels Default: (32, 48).
after the second conv layer in LTD.
Default: 48.
global_in_channels (int): Number of input channels of global_in_channels (int): Number of input channels of
Global Feature Extractor(GFE). Global Feature Extractor(GFE).
Equal to number of output channels of LTD. Equal to number of output channels of LTD.
@ -280,8 +279,7 @@ class FastSCNN(nn.Module):
def __init__(self, def __init__(self,
in_channels=3, in_channels=3,
downsample_dw_channels1=32, downsample_dw_channels=(32, 48),
downsample_dw_channels2=48,
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_out_channels=128, global_out_channels=128,
@ -307,8 +305,8 @@ class FastSCNN(nn.Module):
downsampling factor in the GFE module!') downsampling factor in the GFE module!')
self.in_channels = in_channels self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels1 self.downsample_dw_channels1 = downsample_dw_channels[0]
self.downsample_dw_channels2 = downsample_dw_channels2 self.downsample_dw_channels2 = downsample_dw_channels[1]
self.global_in_channels = global_in_channels self.global_in_channels = global_in_channels
self.global_block_channels = global_block_channels self.global_block_channels = global_block_channels
self.global_out_channels = global_out_channels self.global_out_channels = global_out_channels
@ -323,8 +321,7 @@ class FastSCNN(nn.Module):
self.align_corners = align_corners self.align_corners = align_corners
self.learning_to_downsample = LearningToDownsample( self.learning_to_downsample = LearningToDownsample(
in_channels, in_channels,
downsample_dw_channels1, downsample_dw_channels,
downsample_dw_channels2,
global_in_channels, global_in_channels,
conv_cfg=self.conv_cfg, conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,

View File

@ -669,7 +669,7 @@ def test_resnext_backbone():
def test_fastscnn_backbone(): def test_fastscnn_backbone():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Fast-SCNN channel constraints. # Fast-SCNN channel constraints.
FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128) FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128)
# Test FastSCNN Standard Forward # Test FastSCNN Standard Forward
model = FastSCNN() model = FastSCNN()