Expand_ratio docstrings updated.

This commit is contained in:
johnzja 2020-08-14 13:13:24 +08:00
parent 88a123d16f
commit a9ac0d8188
2 changed files with 15 additions and 17 deletions

View File

@ -7,7 +7,7 @@ model = dict(
downsample_dw_channels=(32, 48), downsample_dw_channels=(32, 48),
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_block_downsample_factors=(2, 2, 1), global_block_strides=(2, 2, 1),
global_out_channels=128, global_out_channels=128,
higher_in_channels=64, higher_in_channels=64,
lower_in_channels=128, lower_in_channels=128,

View File

@ -79,13 +79,14 @@ class GlobalFeatureExtractor(nn.Module):
Default: (64, 96, 128) Default: (64, 96, 128)
out_channels(int): Number of output channels of the GFE module. out_channels(int): Number of output channels of the GFE module.
Default: 128 Default: 128
expand_ratio (int): Upsampling factor of each Inverted Residual expand_ratio (int): Adjusts number of channels of the hidden layer
module. Default: 6 in InvertedResidual by this amount.
Default: 6
num_blocks (tuple[int]): Tuple of ints. Each int specifies the num_blocks (tuple[int]): Tuple of ints. Each int specifies the
number of times each Inverted Residual module is repeated. number of times each Inverted Residual module is repeated.
The repeated Inverted Residual modules are called a 'group'. The repeated Inverted Residual modules are called a 'group'.
Default: (3, 3, 3) Default: (3, 3, 3)
downsample_factors (tuple[int]): Tuple of ints. Each int specifies strides (tuple[int]): Tuple of ints. Each int specifies
the downsampling factor of each 'group'. the downsampling factor of each 'group'.
Default: (2, 2, 1) Default: (2, 2, 1)
pool_scales (tuple[int]): Tuple of ints. Each int specifies pool_scales (tuple[int]): Tuple of ints. Each int specifies
@ -106,7 +107,7 @@ class GlobalFeatureExtractor(nn.Module):
out_channels=128, out_channels=128,
expand_ratio=6, expand_ratio=6,
num_blocks=(3, 3, 3), num_blocks=(3, 3, 3),
downsample_factors=(2, 2, 1), strides=(2, 2, 1),
pool_scales=(1, 2, 3, 6), pool_scales=(1, 2, 3, 6),
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
@ -118,17 +119,14 @@ class GlobalFeatureExtractor(nn.Module):
self.act_cfg = act_cfg self.act_cfg = act_cfg
assert len(block_channels) == len(num_blocks) == 3 assert len(block_channels) == len(num_blocks) == 3
self.bottleneck1 = self._make_layer(in_channels, block_channels[0], self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
num_blocks[0], num_blocks[0], strides[0],
downsample_factors[0],
expand_ratio) expand_ratio)
self.bottleneck2 = self._make_layer(block_channels[0], self.bottleneck2 = self._make_layer(block_channels[0],
block_channels[1], num_blocks[1], block_channels[1], num_blocks[1],
downsample_factors[1], strides[1], expand_ratio)
expand_ratio)
self.bottleneck3 = self._make_layer(block_channels[1], self.bottleneck3 = self._make_layer(block_channels[1],
block_channels[2], num_blocks[2], block_channels[2], num_blocks[2],
downsample_factors[2], strides[2], expand_ratio)
expand_ratio)
self.ppm = PPM( self.ppm = PPM(
pool_scales, pool_scales,
block_channels[2], block_channels[2],
@ -269,8 +267,8 @@ class FastSCNN(nn.Module):
the output channels for each of the MobileNet-v2 bottleneck the output channels for each of the MobileNet-v2 bottleneck
residual blocks in GFE. residual blocks in GFE.
Default: (64, 96, 128). Default: (64, 96, 128).
global_block_downsample_factors (tuple[int]): Tuple of integers global_block_strides (tuple[int]): Tuple of integers
that describe the downsampling factors for each of the that describe the strides (downsampling factors) for each of the
MobileNet-v2 bottleneck residual blocks in GFE. MobileNet-v2 bottleneck residual blocks in GFE.
Default: (2, 2, 1). Default: (2, 2, 1).
global_out_channels (int): Number of output channels of GFE. global_out_channels (int): Number of output channels of GFE.
@ -303,7 +301,7 @@ class FastSCNN(nn.Module):
downsample_dw_channels=(32, 48), downsample_dw_channels=(32, 48),
global_in_channels=64, global_in_channels=64,
global_block_channels=(64, 96, 128), global_block_channels=(64, 96, 128),
global_block_downsample_factors=(2, 2, 1), global_block_strides=(2, 2, 1),
global_out_channels=128, global_out_channels=128,
higher_in_channels=64, higher_in_channels=64,
lower_in_channels=128, lower_in_channels=128,
@ -324,7 +322,7 @@ class FastSCNN(nn.Module):
# Calculate scale factor used in FFM. # Calculate scale factor used in FFM.
self.scale_factor = 1 self.scale_factor = 1
for factor in global_block_downsample_factors: for factor in global_block_strides:
self.scale_factor *= factor self.scale_factor *= factor
self.in_channels = in_channels self.in_channels = in_channels
@ -332,7 +330,7 @@ class FastSCNN(nn.Module):
self.downsample_dw_channels2 = downsample_dw_channels[1] self.downsample_dw_channels2 = downsample_dw_channels[1]
self.global_in_channels = global_in_channels self.global_in_channels = global_in_channels
self.global_block_channels = global_block_channels self.global_block_channels = global_block_channels
self.global_block_downsample_factors = global_block_downsample_factors self.global_block_strides = global_block_strides
self.global_out_channels = global_out_channels self.global_out_channels = global_out_channels
self.higher_in_channels = higher_in_channels self.higher_in_channels = higher_in_channels
self.lower_in_channels = lower_in_channels self.lower_in_channels = lower_in_channels
@ -353,7 +351,7 @@ class FastSCNN(nn.Module):
global_in_channels, global_in_channels,
global_block_channels, global_block_channels,
global_out_channels, global_out_channels,
downsample_factors=self.global_block_downsample_factors, downsample_factors=self.global_block_strides,
conv_cfg=self.conv_cfg, conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg, act_cfg=self.act_cfg,