add unit test for inverted_residual

johnzja 2020-08-11 20:04:31 +08:00
parent 8c553e0eb6
commit ae85850d30
6 changed files with 87 additions and 53 deletions

@@ -12,20 +12,17 @@ class LearningToDownsample(nn.Module):
"""Learning to downsample module.
Args:
in_channels (int): Number of input channels.
dw_channels1 (int): Number of output channels of the first
depthwise conv (dwconv) layer.
dw_channels2 (int): Number of output channels of the second
dwconv layer.
out_channels (int): Number of output channels of the whole
'learning to downsample' module.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
"""
def __init__(self,
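For reference, a minimal plain-PyTorch sketch of the Learning-To-Downsample stage described above, assuming the standard Fast-SCNN layout of one strided conv followed by two strided depthwise separable convs (the repo's module is built from mmcv ConvModules driven by conv_cfg/norm_cfg/act_cfg, so details differ); the class and helper names here are made up for illustration:

import torch
import torch.nn as nn


def dsconv(inp, oup, stride):
    # depthwise 3x3 followed by pointwise 1x1
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp), nn.ReLU(inplace=True),
        nn.Conv2d(inp, oup, 1, bias=False),
        nn.BatchNorm2d(oup), nn.ReLU(inplace=True))


class LearningToDownsampleSketch(nn.Module):
    def __init__(self, in_channels=3, dw_channels1=32, dw_channels2=48,
                 out_channels=64):
        super().__init__()
        # three stride-2 stages -> overall 1/8 resolution
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, dw_channels1, 3, 2, 1, bias=False),
            nn.BatchNorm2d(dw_channels1), nn.ReLU(inplace=True),
            dsconv(dw_channels1, dw_channels2, 2),
            dsconv(dw_channels2, out_channels, 2))

    def forward(self, x):
        return self.layers(x)


x = torch.randn(1, 3, 512, 1024)
assert LearningToDownsampleSketch()(x).shape == (1, 64, 64, 128)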
@@ -89,10 +86,13 @@ class GlobalFeatureExtractor(nn.Module):
pool_scales (tuple): Tuple of ints. Each int specifies the parameter
required in 'global average pooling' within PPM.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
"""
def __init__(self,
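The 'global average pooling' step that pool_scales controls can be pictured with plain adaptive pooling; the scales below are illustrative, not necessarily the module's defaults:

import torch
import torch.nn.functional as F

feat = torch.randn(1, 128, 32, 64)
# each pool_scale is the output size of one adaptive average pooling branch
pooled = [F.adaptive_avg_pool2d(feat, scale) for scale in (1, 2, 3, 6)]
assert [tuple(p.shape[-2:]) for p in pooled] == [(1, 1), (2, 2), (3, 3), (6, 6)]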
@@ -105,7 +105,7 @@ class GlobalFeatureExtractor(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=True):
align_corners=False):
super(GlobalFeatureExtractor, self).__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
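This hunk and the FeatureFusionModule hunk below flip the align_corners default from True to False. A quick illustration of what the flag changes in F.interpolate: the two settings map output pixels to the input grid differently, so the interpolated values differ, especially near the borders.

import torch
import torch.nn.functional as F

x = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
up_false = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
up_true = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
# same shape, different interpolated values
assert up_false.shape == up_true.shape == (1, 1, 4, 4)
assert not torch.allclose(up_false, up_true)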
@@ -169,10 +169,13 @@ class FeatureFusionModule(nn.Module):
Should be coherent with the downsampling factor determined
by the GFE module.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
"""
def __init__(self,
@@ -183,7 +186,7 @@ class FeatureFusionModule(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=True):
align_corners=False):
super(FeatureFusionModule, self).__init__()
self.scale_factor = scale_factor
self.conv_cfg = conv_cfg
@@ -231,46 +234,48 @@ class FeatureFusionModule(nn.Module):
class FastSCNN(nn.Module):
"""Fast-SCNN Backbone.
Args:
in_channels (int): Number of input image channels. Default=3 (RGB)
in_channels (int): Number of input image channels. Default: 3.
downsample_dw_channels1 (int): Number of output channels after
the first conv layer in Learning-To-Downsample (LTD) module.
Default: 32.
downsample_dw_channels2 (int): Number of output channels
after the second conv layer in LTD.
Default: 48.
global_in_channels (int): Number of input channels of
Global Feature Extractor(GFE).
Equal to number of output channels of LTD.
Default: 64.
global_block_channels (tuple): Tuple of integers that describe
the output channels for each of the MobileNet-v2 bottleneck
residual blocks in GFE.
Default: (64, 96, 128).
global_out_channels (int): Number of output channels of GFE.
Default: 128.
higher_in_channels (int): Number of input channels of the higher
resolution branch in FFM.
Equal to global_in_channels.
Default: 64.
lower_in_channels (int): Number of input channels of the lower
resolution branch in FFM.
Equal to global_out_channels.
Default: 128.
fusion_out_channels (int): Number of output channels of FFM.
Default: 128.
scale_factor (int): The upsampling factor of the higher resolution
branch in FFM.
Equal to the downsampling factor in GFE.
Default: 4.
out_indices (tuple): Tuple of indices of list
[higher_res_features, lower_res_features, fusion_output].
Often set to (0,1,2) to enable aux. heads.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
Default: (0, 1, 2).
conv_cfg (dict | None): Config of conv layers. Default: None
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN')
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
"""
def __init__(self,
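Collecting the documented defaults, a backbone config for this module might look roughly like the sketch below. It is assembled only from the docstring above; the repo's actual config files may group or name things differently.

backbone = dict(
    type='FastSCNN',
    in_channels=3,
    downsample_dw_channels1=32,
    downsample_dw_channels2=48,
    global_in_channels=64,
    global_block_channels=(64, 96, 128),
    global_out_channels=128,
    higher_in_channels=64,
    lower_in_channels=128,
    fusion_out_channels=128,
    scale_factor=4,
    out_indices=(0, 1, 2),
    norm_cfg=dict(type='BN'),
    align_corners=False)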

@@ -11,21 +11,16 @@ class DepthwiseSeparableFCNHead(FCNHead):
This head is implemented according to Fast-SCNN.
Args:
in_channels(int): Number of output channels of FFM.
channels(int): Number of middle-stage channels in the decode head.
concat_input(bool): Whether to concatenate original decode input into
the result of several consecutive convolution layers.
Default: True.
num_classes(int): Used to determine the dimension of
final prediction tensor.
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
norm_cfg (dict|None): Config of norm layers.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
loss_decode(dict): Config of loss type and some
relevant additional options.
"""

@@ -9,10 +9,12 @@ class InvertedResidual(nn.Module):
oup (int): output channels.
stride (int): downsampling factor.
expand_ratio (int): 1 or 2.
dilation (int): Dilated conv. Default: 1
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
dilation (int): Dilated conv. Default: 1.
conv_cfg (dict | None): Config of conv layers. Default: None.
norm_cfg (dict | None): Config of norm layers. Default:
dict(type='BN').
act_cfg (dict): Config of activation layers. Default:
dict(type='ReLU6').
"""
def __init__(self,
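For context, a plain-torch sketch of the MobileNetV2-style block this docstring describes: a 1x1 expansion, a 3x3 depthwise conv carrying the stride and dilation, and a linear 1x1 projection, with the residual shortcut only when stride == 1 and inp == oup. The repo's version is built from mmcv ConvModules with the configs above, so this is illustrative only.

import torch
import torch.nn as nn


class InvertedResidualSketch(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
        super().__init__()
        assert stride in (1, 2)
        hidden = inp * expand_ratio
        # the shortcut is only valid when spatial size and channels are kept
        self.use_res_connect = stride == 1 and inp == oup
        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise expansion
            layers += [nn.Conv2d(inp, hidden, 1, bias=False),
                       nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True)]
        layers += [
            # 3x3 depthwise conv carries the stride and dilation
            nn.Conv2d(hidden, hidden, 3, stride, dilation, dilation,
                      groups=hidden, bias=False),
            nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True),
            # linear 1x1 projection, no activation
            nn.Conv2d(hidden, oup, 1, bias=False), nn.BatchNorm2d(oup)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out


x = torch.rand(1, 32, 64, 64)
assert InvertedResidualSketch(32, 32, 1, 4)(x).shape == (1, 32, 64, 64)
assert InvertedResidualSketch(32, 32, 2, 4)(x).shape == (1, 32, 32, 32)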

@@ -41,12 +41,9 @@ def collect_env():
for name, devids in devices.items():
env_info['GPU ' + ','.join(devids)] = name
try:
gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
gcc = gcc.decode('utf-8').strip()
env_info['GCC'] = gcc
except subprocess.CalledProcessError:
env_info['GCC'] = 'n/a'
gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
gcc = gcc.decode('utf-8').strip()
env_info['GCC'] = gcc
env_info['PyTorch'] = torch.__version__
env_info['PyTorch compiling details'] = get_build_config()

@@ -675,14 +675,14 @@ def test_fastscnn_backbone():
model = FastSCNN()
model.init_weights()
model.train()
num_batch_picts = 4
imgs = torch.randn(num_batch_picts, 3, 1024, 2048)
batch_size = 4
imgs = torch.randn(batch_size, 3, 1024, 2048)
feat = model(imgs)
assert len(feat) == 3
# higher-res
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256])
assert feat[0].shape == torch.Size([batch_size, 64, 128, 256])
# lower-res
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64])
assert feat[1].shape == torch.Size([batch_size, 128, 32, 64])
# FFM output
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256])
assert feat[2].shape == torch.Size([batch_size, 128, 128, 256])

@@ -0,0 +1,35 @@
import pytest
import torch
import torch.nn as nn

from mmseg.ops import InvertedResidual


def test_inv_residual():
    with pytest.raises(AssertionError):
        # test stride assertion.
        InvertedResidual(32, 32, 3, 4)

    # test default config with res connection.
    # set expand_ratio = 4, stride = 1 and inp=oup.
    inv_module = InvertedResidual(32, 32, 1, 4)
    assert inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == 3
    assert inv_module.conv[0].padding == 1
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)

    # test inv_residual module without res connection.
    # set expand_ratio = 4, stride = 2.
    inv_module = InvertedResidual(32, 32, 2, 4)
    assert not inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == 1
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 32, 32)
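The shape assertions follow directly from the stride: with stride 1 and inp == oup the block keeps the 64x64 resolution (and use_res_connect is set), while stride 2 halves it to 32x32 and disables the shortcut. The kernel_size and padding checks depend on how the module exposes its first conv layer, so they are specific to this implementation. The test can be run on its own with pytest's keyword filter, e.g. pytest -k test_inv_residual.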