FastSCNN type(dwchannels) changed to tuple.

pull/58/head
johnzja 2020-08-12 11:22:11 +08:00
parent ed3a6d0a70
commit b08e1d4e9e
3 changed files with 17 additions and 21 deletions

View File

@ -4,8 +4,7 @@ model = dict(
type='EncoderDecoder',
backbone=dict(
type='FastSCNN',
downsample_dw_channels1=32,
downsample_dw_channels2=48,
downsample_dw_channels=(32, 48),
global_in_channels=64,
global_block_channels=(64, 96, 128),
global_out_channels=128,

View File

@ -10,12 +10,11 @@ from ..builder import BACKBONES
class LearningToDownsample(nn.Module):
"""Learning to downsample module.
Args:
in_channels (int): Number of input channels.
dw_channels1 (int): Number of output channels of the first
depthwise conv (dwconv) layer.
dw_channels2 (int): Number of output channels of the second
dwconv layer.
dw_channels (tuple): Number of output channels of the first and
the second depthwise conv (dwconv) layers.
out_channels (int): Number of output channels of the whole
'learning to downsample' module.
conv_cfg (dict | None): Config of conv layers. Default: None
@ -27,8 +26,7 @@ class LearningToDownsample(nn.Module):
def __init__(self,
in_channels,
dw_channels1,
dw_channels2,
dw_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
@ -37,6 +35,9 @@ class LearningToDownsample(nn.Module):
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
dw_channels1 = dw_channels[0]
dw_channels2 = dw_channels[1]
self.conv = ConvModule(
in_channels,
dw_channels1,
@ -235,12 +236,10 @@ class FastSCNN(nn.Module):
"""Fast-SCNN Backbone.
Args:
in_channels (int): Number of input image channels. Default: 3.
downsample_dw_channels1 (int): Number of output channels after
the first conv layer in Learning-To-Downsample (LTD) module.
Default: 32.
downsample_dw_channels2 (int): Number of output channels
after the second conv layer in LTD.
Default: 48.
downsample_dw_channels (tuple): Number of output channels after
the first conv layer & the second conv layer in
Learning-To-Downsample (LTD) module.
Default: (32, 48).
global_in_channels (int): Number of input channels of
Global Feature Extractor(GFE).
Equal to number of output channels of LTD.
@ -280,8 +279,7 @@ class FastSCNN(nn.Module):
def __init__(self,
in_channels=3,
downsample_dw_channels1=32,
downsample_dw_channels2=48,
downsample_dw_channels=(32, 48),
global_in_channels=64,
global_block_channels=(64, 96, 128),
global_out_channels=128,
@ -307,8 +305,8 @@ class FastSCNN(nn.Module):
downsampling factor in the GFE module!')
self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels1
self.downsample_dw_channels2 = downsample_dw_channels2
self.downsample_dw_channels1 = downsample_dw_channels[0]
self.downsample_dw_channels2 = downsample_dw_channels[1]
self.global_in_channels = global_in_channels
self.global_block_channels = global_block_channels
self.global_out_channels = global_out_channels
@ -323,8 +321,7 @@ class FastSCNN(nn.Module):
self.align_corners = align_corners
self.learning_to_downsample = LearningToDownsample(
in_channels,
downsample_dw_channels1,
downsample_dw_channels2,
downsample_dw_channels,
global_in_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,

View File

@ -669,7 +669,7 @@ def test_resnext_backbone():
def test_fastscnn_backbone():
with pytest.raises(AssertionError):
# Fast-SCNN channel constraints.
FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128)
FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128)
# Test FastSCNN Standard Forward
model = FastSCNN()