Arg scale_factor deleted.

This commit is contained in:
johnzja 2020-08-14 12:09:22 +08:00
parent a8a5ff80b3
commit 88a123d16f
3 changed files with 31 additions and 16 deletions

View File

@ -7,11 +7,11 @@ model = dict(
downsample_dw_channels=(32, 48), downsample_dw_channels=(32, 48),
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_block_downsample_factors=(2, 2, 1),
global_out_channels=128, global_out_channels=128,
higher_in_channels=64, higher_in_channels=64,
lower_in_channels=128, lower_in_channels=128,
fusion_out_channels=128, fusion_out_channels=128,
scale_factor=4,
out_indices=(0, 1, 2), out_indices=(0, 1, 2),
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
align_corners=False), align_corners=False),

View File

@ -83,7 +83,11 @@ class GlobalFeatureExtractor(nn.Module):
module. Default: 6 module. Default: 6
num_blocks (tuple[int]): Tuple of ints. Each int specifies the num_blocks (tuple[int]): Tuple of ints. Each int specifies the
number of times each Inverted Residual module is repeated. number of times each Inverted Residual module is repeated.
The repeated Inverted Residual modules are called a 'group'.
Default: (3, 3, 3) Default: (3, 3, 3)
downsample_factors (tuple[int]): Tuple of ints. Each int specifies
the downsampling factor of each 'group'.
Default: (2, 2, 1)
pool_scales (tuple[int]): Tuple of ints. Each int specifies pool_scales (tuple[int]): Tuple of ints. Each int specifies
the parameter required in 'global average pooling' within PPM. the parameter required in 'global average pooling' within PPM.
Default: (1, 2, 3, 6) Default: (1, 2, 3, 6)
@ -102,6 +106,7 @@ class GlobalFeatureExtractor(nn.Module):
out_channels=128, out_channels=128,
expand_ratio=6, expand_ratio=6,
num_blocks=(3, 3, 3), num_blocks=(3, 3, 3),
downsample_factors=(2, 2, 1),
pool_scales=(1, 2, 3, 6), pool_scales=(1, 2, 3, 6),
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
@ -113,13 +118,17 @@ class GlobalFeatureExtractor(nn.Module):
self.act_cfg = act_cfg self.act_cfg = act_cfg
assert len(block_channels) == len(num_blocks) == 3 assert len(block_channels) == len(num_blocks) == 3
self.bottleneck1 = self._make_layer(in_channels, block_channels[0], self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
num_blocks[0], 2, expand_ratio) num_blocks[0],
downsample_factors[0],
expand_ratio)
self.bottleneck2 = self._make_layer(block_channels[0], self.bottleneck2 = self._make_layer(block_channels[0],
block_channels[1], num_blocks[1], block_channels[1], num_blocks[1],
2, expand_ratio) downsample_factors[1],
expand_ratio)
self.bottleneck3 = self._make_layer(block_channels[1], self.bottleneck3 = self._make_layer(block_channels[1],
block_channels[2], num_blocks[2], block_channels[2], num_blocks[2],
1, expand_ratio) downsample_factors[2],
expand_ratio)
self.ppm = PPM( self.ppm = PPM(
pool_scales, pool_scales,
block_channels[2], block_channels[2],
@ -260,6 +269,10 @@ class FastSCNN(nn.Module):
the output channels for each of the MobileNet-v2 bottleneck the output channels for each of the MobileNet-v2 bottleneck
residual blocks in GFE. residual blocks in GFE.
Default: (64, 96, 128). Default: (64, 96, 128).
global_block_downsample_factors (tuple[int]): Tuple of integers
that describe the downsampling factors for each of the
MobileNet-v2 bottleneck residual blocks in GFE.
Default: (2, 2, 1).
global_out_channels (int): Number of output channels of GFE. global_out_channels (int): Number of output channels of GFE.
Default: 128. Default: 128.
higher_in_channels (int): Number of input channels of the higher higher_in_channels (int): Number of input channels of the higher
@ -272,10 +285,6 @@ class FastSCNN(nn.Module):
Default: 128. Default: 128.
fusion_out_channels (int): Number of output channels of FFM. fusion_out_channels (int): Number of output channels of FFM.
Default: 128. Default: 128.
scale_factor (int): The upsampling factor of the higher resolution
branch in FFM.
Equal to the downsampling factor in GFE.
Default: 4.
out_indices (tuple): Tuple of indices of list out_indices (tuple): Tuple of indices of list
[higher_res_features, lower_res_features, fusion_output]. [higher_res_features, lower_res_features, fusion_output].
Often set to (0,1,2) to enable aux. heads. Often set to (0,1,2) to enable aux. heads.
@ -294,11 +303,11 @@ class FastSCNN(nn.Module):
downsample_dw_channels=(32, 48), downsample_dw_channels=(32, 48),
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_block_downsample_factors=(2, 2, 1),
global_out_channels=128, global_out_channels=128,
higher_in_channels=64, higher_in_channels=64,
lower_in_channels=128, lower_in_channels=128,
fusion_out_channels=128, fusion_out_channels=128,
scale_factor=4,
out_indices=(0, 1, 2), out_indices=(0, 1, 2),
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
@ -312,20 +321,22 @@ class FastSCNN(nn.Module):
elif global_out_channels != lower_in_channels: elif global_out_channels != lower_in_channels:
raise AssertionError('Global Output Channels must be the same \ raise AssertionError('Global Output Channels must be the same \
with Lower Input Channels!') with Lower Input Channels!')
if scale_factor != 4:
raise AssertionError('Scale-factor must compensate the \ # Calculate scale factor used in FFM.
downsampling factor in the GFE module!') self.scale_factor = 1
for factor in global_block_downsample_factors:
self.scale_factor *= factor
self.in_channels = in_channels self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels[0] self.downsample_dw_channels1 = downsample_dw_channels[0]
self.downsample_dw_channels2 = downsample_dw_channels[1] self.downsample_dw_channels2 = downsample_dw_channels[1]
self.global_in_channels = global_in_channels self.global_in_channels = global_in_channels
self.global_block_channels = global_block_channels self.global_block_channels = global_block_channels
self.global_block_downsample_factors = global_block_downsample_factors
self.global_out_channels = global_out_channels self.global_out_channels = global_out_channels
self.higher_in_channels = higher_in_channels self.higher_in_channels = higher_in_channels
self.lower_in_channels = lower_in_channels self.lower_in_channels = lower_in_channels
self.fusion_out_channels = fusion_out_channels self.fusion_out_channels = fusion_out_channels
self.scale_factor = scale_factor
self.out_indices = out_indices self.out_indices = out_indices
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
@ -342,6 +353,7 @@ class FastSCNN(nn.Module):
global_in_channels, global_in_channels,
global_block_channels, global_block_channels,
global_out_channels, global_out_channels,
downsample_factors=self.global_block_downsample_factors,
conv_cfg=self.conv_cfg, conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg, act_cfg=self.act_cfg,

View File

@ -47,7 +47,6 @@ def check_norm_state(modules, train_state):
def test_resnet_basic_block(): def test_resnet_basic_block():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Not implemented yet. # Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
@ -97,7 +96,6 @@ def test_resnet_basic_block():
def test_resnet_bottleneck(): def test_resnet_bottleneck():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe'] # Style must be in ['pytorch', 'caffe']
Bottleneck(64, 64, style='tensorflow') Bottleneck(64, 64, style='tensorflow')
@ -669,7 +667,12 @@ def test_resnext_backbone():
def test_fastscnn_backbone(): def test_fastscnn_backbone():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Fast-SCNN channel constraints. # Fast-SCNN channel constraints.
FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128) FastSCNN(
3, (32, 48),
64, (64, 96, 128), (2, 2, 1),
global_out_channels=127,
higher_in_channels=64,
lower_in_channels=128)
# Test FastSCNN Standard Forward # Test FastSCNN Standard Forward
model = FastSCNN() model = FastSCNN()