FastSCNN type(dwchannels) changed to tuple.
parent
ed3a6d0a70
commit
b08e1d4e9e
|
@ -4,8 +4,7 @@ model = dict(
|
|||
type='EncoderDecoder',
|
||||
backbone=dict(
|
||||
type='FastSCNN',
|
||||
downsample_dw_channels1=32,
|
||||
downsample_dw_channels2=48,
|
||||
downsample_dw_channels=(32, 48),
|
||||
global_in_channels=64,
|
||||
global_block_channels=(64, 96, 128),
|
||||
global_out_channels=128,
|
||||
|
|
|
@ -10,12 +10,11 @@ from ..builder import BACKBONES
|
|||
|
||||
class LearningToDownsample(nn.Module):
|
||||
"""Learning to downsample module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
dw_channels1 (int): Number of output channels of the first
|
||||
depthwise conv (dwconv) layer.
|
||||
dw_channels2 (int): Number of output channels of the second
|
||||
dwconv layer.
|
||||
dw_channels (tuple): Number of output channels of the first and
|
||||
the second depthwise conv (dwconv) layers.
|
||||
out_channels (int): Number of output channels of the whole
|
||||
'learning to downsample' module.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
|
@ -27,8 +26,7 @@ class LearningToDownsample(nn.Module):
|
|||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
dw_channels1,
|
||||
dw_channels2,
|
||||
dw_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
|
@ -37,6 +35,9 @@ class LearningToDownsample(nn.Module):
|
|||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
dw_channels1 = dw_channels[0]
|
||||
dw_channels2 = dw_channels[1]
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
dw_channels1,
|
||||
|
@ -235,12 +236,10 @@ class FastSCNN(nn.Module):
|
|||
"""Fast-SCNN Backbone.
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
downsample_dw_channels1 (int): Number of output channels after
|
||||
the first conv layer in Learning-To-Downsample (LTD) module.
|
||||
Default: 32.
|
||||
downsample_dw_channels2 (int): Number of output channels
|
||||
after the second conv layer in LTD.
|
||||
Default: 48.
|
||||
downsample_dw_channels (tuple): Number of output channels after
|
||||
the first conv layer & the second conv layer in
|
||||
Learning-To-Downsample (LTD) module.
|
||||
Default: (32, 48).
|
||||
global_in_channels (int): Number of input channels of
|
||||
Global Feature Extractor(GFE).
|
||||
Equal to number of output channels of LTD.
|
||||
|
@ -280,8 +279,7 @@ class FastSCNN(nn.Module):
|
|||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
downsample_dw_channels1=32,
|
||||
downsample_dw_channels2=48,
|
||||
downsample_dw_channels=(32, 48),
|
||||
global_in_channels=64,
|
||||
global_block_channels=(64, 96, 128),
|
||||
global_out_channels=128,
|
||||
|
@ -307,8 +305,8 @@ class FastSCNN(nn.Module):
|
|||
downsampling factor in the GFE module!')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_dw_channels1 = downsample_dw_channels1
|
||||
self.downsample_dw_channels2 = downsample_dw_channels2
|
||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||
self.global_in_channels = global_in_channels
|
||||
self.global_block_channels = global_block_channels
|
||||
self.global_out_channels = global_out_channels
|
||||
|
@ -323,8 +321,7 @@ class FastSCNN(nn.Module):
|
|||
self.align_corners = align_corners
|
||||
self.learning_to_downsample = LearningToDownsample(
|
||||
in_channels,
|
||||
downsample_dw_channels1,
|
||||
downsample_dw_channels2,
|
||||
downsample_dw_channels,
|
||||
global_in_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
|
|
|
@ -669,7 +669,7 @@ def test_resnext_backbone():
|
|||
def test_fastscnn_backbone():
|
||||
with pytest.raises(AssertionError):
|
||||
# Fast-SCNN channel constraints.
|
||||
FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128)
|
||||
FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128)
|
||||
|
||||
# Test FastSCNN Standard Forward
|
||||
model = FastSCNN()
|
||||
|
|
Loading…
Reference in New Issue