mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Arg scale_factor deleted.
This commit is contained in:
parent
a8a5ff80b3
commit
88a123d16f
@ -7,11 +7,11 @@ model = dict(
|
|||||||
downsample_dw_channels=(32, 48),
|
downsample_dw_channels=(32, 48),
|
||||||
global_in_channels=64,
|
global_in_channels=64,
|
||||||
global_block_channels=(64, 96, 128),
|
global_block_channels=(64, 96, 128),
|
||||||
|
global_block_downsample_factors=(2, 2, 1),
|
||||||
global_out_channels=128,
|
global_out_channels=128,
|
||||||
higher_in_channels=64,
|
higher_in_channels=64,
|
||||||
lower_in_channels=128,
|
lower_in_channels=128,
|
||||||
fusion_out_channels=128,
|
fusion_out_channels=128,
|
||||||
scale_factor=4,
|
|
||||||
out_indices=(0, 1, 2),
|
out_indices=(0, 1, 2),
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
align_corners=False),
|
align_corners=False),
|
||||||
|
@ -83,7 +83,11 @@ class GlobalFeatureExtractor(nn.Module):
|
|||||||
module. Default: 6
|
module. Default: 6
|
||||||
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
||||||
number of times each Inverted Residual module is repeated.
|
number of times each Inverted Residual module is repeated.
|
||||||
|
The repeated Inverted Residual modules are called a 'group'.
|
||||||
Default: (3, 3, 3)
|
Default: (3, 3, 3)
|
||||||
|
downsample_factors (tuple[int]): Tuple of ints. Each int specifies
|
||||||
|
the downsampling factor of each 'group'.
|
||||||
|
Default: (2, 2, 1)
|
||||||
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
||||||
the parameter required in 'global average pooling' within PPM.
|
the parameter required in 'global average pooling' within PPM.
|
||||||
Default: (1, 2, 3, 6)
|
Default: (1, 2, 3, 6)
|
||||||
@ -102,6 +106,7 @@ class GlobalFeatureExtractor(nn.Module):
|
|||||||
out_channels=128,
|
out_channels=128,
|
||||||
expand_ratio=6,
|
expand_ratio=6,
|
||||||
num_blocks=(3, 3, 3),
|
num_blocks=(3, 3, 3),
|
||||||
|
downsample_factors=(2, 2, 1),
|
||||||
pool_scales=(1, 2, 3, 6),
|
pool_scales=(1, 2, 3, 6),
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
@ -113,13 +118,17 @@ class GlobalFeatureExtractor(nn.Module):
|
|||||||
self.act_cfg = act_cfg
|
self.act_cfg = act_cfg
|
||||||
assert len(block_channels) == len(num_blocks) == 3
|
assert len(block_channels) == len(num_blocks) == 3
|
||||||
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
||||||
num_blocks[0], 2, expand_ratio)
|
num_blocks[0],
|
||||||
|
downsample_factors[0],
|
||||||
|
expand_ratio)
|
||||||
self.bottleneck2 = self._make_layer(block_channels[0],
|
self.bottleneck2 = self._make_layer(block_channels[0],
|
||||||
block_channels[1], num_blocks[1],
|
block_channels[1], num_blocks[1],
|
||||||
2, expand_ratio)
|
downsample_factors[1],
|
||||||
|
expand_ratio)
|
||||||
self.bottleneck3 = self._make_layer(block_channels[1],
|
self.bottleneck3 = self._make_layer(block_channels[1],
|
||||||
block_channels[2], num_blocks[2],
|
block_channels[2], num_blocks[2],
|
||||||
1, expand_ratio)
|
downsample_factors[2],
|
||||||
|
expand_ratio)
|
||||||
self.ppm = PPM(
|
self.ppm = PPM(
|
||||||
pool_scales,
|
pool_scales,
|
||||||
block_channels[2],
|
block_channels[2],
|
||||||
@ -260,6 +269,10 @@ class FastSCNN(nn.Module):
|
|||||||
the output channels for each of the MobileNet-v2 bottleneck
|
the output channels for each of the MobileNet-v2 bottleneck
|
||||||
residual blocks in GFE.
|
residual blocks in GFE.
|
||||||
Default: (64, 96, 128).
|
Default: (64, 96, 128).
|
||||||
|
global_block_downsample_factors (tuple[int]): Tuple of integers
|
||||||
|
that describe the downsampling factors for each of the
|
||||||
|
MobileNet-v2 bottleneck residual blocks in GFE.
|
||||||
|
Default: (2, 2, 1).
|
||||||
global_out_channels (int): Number of output channels of GFE.
|
global_out_channels (int): Number of output channels of GFE.
|
||||||
Default: 128.
|
Default: 128.
|
||||||
higher_in_channels (int): Number of input channels of the higher
|
higher_in_channels (int): Number of input channels of the higher
|
||||||
@ -272,10 +285,6 @@ class FastSCNN(nn.Module):
|
|||||||
Default: 128.
|
Default: 128.
|
||||||
fusion_out_channels (int): Number of output channels of FFM.
|
fusion_out_channels (int): Number of output channels of FFM.
|
||||||
Default: 128.
|
Default: 128.
|
||||||
scale_factor (int): The upsampling factor of the higher resolution
|
|
||||||
branch in FFM.
|
|
||||||
Equal to the downsampling factor in GFE.
|
|
||||||
Default: 4.
|
|
||||||
out_indices (tuple): Tuple of indices of list
|
out_indices (tuple): Tuple of indices of list
|
||||||
[higher_res_features, lower_res_features, fusion_output].
|
[higher_res_features, lower_res_features, fusion_output].
|
||||||
Often set to (0,1,2) to enable aux. heads.
|
Often set to (0,1,2) to enable aux. heads.
|
||||||
@ -294,11 +303,11 @@ class FastSCNN(nn.Module):
|
|||||||
downsample_dw_channels=(32, 48),
|
downsample_dw_channels=(32, 48),
|
||||||
global_in_channels=64,
|
global_in_channels=64,
|
||||||
global_block_channels=(64, 96, 128),
|
global_block_channels=(64, 96, 128),
|
||||||
|
global_block_downsample_factors=(2, 2, 1),
|
||||||
global_out_channels=128,
|
global_out_channels=128,
|
||||||
higher_in_channels=64,
|
higher_in_channels=64,
|
||||||
lower_in_channels=128,
|
lower_in_channels=128,
|
||||||
fusion_out_channels=128,
|
fusion_out_channels=128,
|
||||||
scale_factor=4,
|
|
||||||
out_indices=(0, 1, 2),
|
out_indices=(0, 1, 2),
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
@ -312,20 +321,22 @@ class FastSCNN(nn.Module):
|
|||||||
elif global_out_channels != lower_in_channels:
|
elif global_out_channels != lower_in_channels:
|
||||||
raise AssertionError('Global Output Channels must be the same \
|
raise AssertionError('Global Output Channels must be the same \
|
||||||
with Lower Input Channels!')
|
with Lower Input Channels!')
|
||||||
if scale_factor != 4:
|
|
||||||
raise AssertionError('Scale-factor must compensate the \
|
# Calculate scale factor used in FFM.
|
||||||
downsampling factor in the GFE module!')
|
self.scale_factor = 1
|
||||||
|
for factor in global_block_downsample_factors:
|
||||||
|
self.scale_factor *= factor
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||||
self.global_in_channels = global_in_channels
|
self.global_in_channels = global_in_channels
|
||||||
self.global_block_channels = global_block_channels
|
self.global_block_channels = global_block_channels
|
||||||
|
self.global_block_downsample_factors = global_block_downsample_factors
|
||||||
self.global_out_channels = global_out_channels
|
self.global_out_channels = global_out_channels
|
||||||
self.higher_in_channels = higher_in_channels
|
self.higher_in_channels = higher_in_channels
|
||||||
self.lower_in_channels = lower_in_channels
|
self.lower_in_channels = lower_in_channels
|
||||||
self.fusion_out_channels = fusion_out_channels
|
self.fusion_out_channels = fusion_out_channels
|
||||||
self.scale_factor = scale_factor
|
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
@ -342,6 +353,7 @@ class FastSCNN(nn.Module):
|
|||||||
global_in_channels,
|
global_in_channels,
|
||||||
global_block_channels,
|
global_block_channels,
|
||||||
global_out_channels,
|
global_out_channels,
|
||||||
|
downsample_factors=self.global_block_downsample_factors,
|
||||||
conv_cfg=self.conv_cfg,
|
conv_cfg=self.conv_cfg,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
act_cfg=self.act_cfg,
|
act_cfg=self.act_cfg,
|
||||||
|
@ -47,7 +47,6 @@ def check_norm_state(modules, train_state):
|
|||||||
|
|
||||||
|
|
||||||
def test_resnet_basic_block():
|
def test_resnet_basic_block():
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Not implemented yet.
|
# Not implemented yet.
|
||||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||||
@ -97,7 +96,6 @@ def test_resnet_basic_block():
|
|||||||
|
|
||||||
|
|
||||||
def test_resnet_bottleneck():
|
def test_resnet_bottleneck():
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Style must be in ['pytorch', 'caffe']
|
# Style must be in ['pytorch', 'caffe']
|
||||||
Bottleneck(64, 64, style='tensorflow')
|
Bottleneck(64, 64, style='tensorflow')
|
||||||
@ -669,7 +667,12 @@ def test_resnext_backbone():
|
|||||||
def test_fastscnn_backbone():
|
def test_fastscnn_backbone():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Fast-SCNN channel constraints.
|
# Fast-SCNN channel constraints.
|
||||||
FastSCNN(3, (32, 48), 64, (64, 96, 128), 127, 64, 128)
|
FastSCNN(
|
||||||
|
3, (32, 48),
|
||||||
|
64, (64, 96, 128), (2, 2, 1),
|
||||||
|
global_out_channels=127,
|
||||||
|
higher_in_channels=64,
|
||||||
|
lower_in_channels=128)
|
||||||
|
|
||||||
# Test FastSCNN Standard Forward
|
# Test FastSCNN Standard Forward
|
||||||
model = FastSCNN()
|
model = FastSCNN()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user