Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)

Commit ae85850d30: add unit test for inverted_residual
Parent commit: 8c553e0eb6
@@ -12,20 +12,17 @@ class LearningToDownsample(nn.Module):
     """Learning to downsample module.
 
     Args:
         in_channels (int): Number of input channels.
-
         dw_channels1 (int): Number of output channels of the first
             depthwise conv (dwconv) layer.
-
         dw_channels2 (int): Number of output channels of the second
             dwconv layer.
-
         out_channels (int): Number of output channels of the whole
             'learning to downsample' module.
-        conv_cfg (dict|None): Config of conv layers.
-        norm_cfg (dict|None): Config of norm layers.
-        act_cfg (dict): Config of activation layers.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
     """
 
     def __init__(self,
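The cfg dicts documented above are consumed by mmcv's ConvModule, which builds the conv, norm, and activation layers from them. A minimal sketch of that mechanism, not part of this diff (the ConvModule call is mmcv's public API; the channel sizes are illustrative):

    import torch
    from mmcv.cnn import ConvModule

    # dict(type='BN') and dict(type='ReLU') are resolved by mmcv into
    # nn.BatchNorm2d and nn.ReLU; conv_cfg=None falls back to nn.Conv2d.
    conv = ConvModule(
        in_channels=3,
        out_channels=32,
        kernel_size=3,
        padding=1,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))

    out = conv(torch.randn(1, 3, 64, 64))
    assert out.shape == (1, 32, 64, 64)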
@@ -89,10 +86,13 @@ class GlobalFeatureExtractor(nn.Module):
         pool_scales (tuple): Tuple of ints. Each int specifies the parameter
             required in 'global average pooling' within PPM.
 
-        conv_cfg (dict|None): Config of conv layers.
-        norm_cfg (dict|None): Config of norm layers.
-        act_cfg (dict): Config of activation layers.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
         align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
     """
 
     def __init__(self,
@@ -105,7 +105,7 @@ class GlobalFeatureExtractor(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
-                 align_corners=True):
+                 align_corners=False):
         super(GlobalFeatureExtractor, self).__init__()
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
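The align_corners default flips from True to False here and in FeatureFusionModule below, matching the "Default: False" now documented in the docstrings. The two modes genuinely produce different outputs; a quick standalone check, not part of the diff:

    import torch
    import torch.nn.functional as F

    x = torch.arange(4.).reshape(1, 1, 2, 2)
    # align_corners=True pins input and output corner pixels together;
    # align_corners=False treats pixels as unit-area samples instead.
    up_t = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
    up_f = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
    assert not torch.allclose(up_t, up_f)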
@@ -169,10 +169,13 @@ class FeatureFusionModule(nn.Module):
             Should be coherent with the downsampling factor determined
             by the GFE module.
 
-        conv_cfg (dict|None): Config of conv layers.
-        norm_cfg (dict|None): Config of norm layers.
-        act_cfg (dict): Config of activation layers.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
         align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
     """
 
     def __init__(self,
@@ -183,7 +186,7 @@ class FeatureFusionModule(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
-                 align_corners=True):
+                 align_corners=False):
         super(FeatureFusionModule, self).__init__()
         self.scale_factor = scale_factor
         self.conv_cfg = conv_cfg
@@ -231,46 +234,48 @@ class FastSCNN(nn.Module):
 class FastSCNN(nn.Module):
     """Fast-SCNN Backbone.
 
     Args:
-        in_channels (int): Number of input image channels. Default=3 (RGB)
-
+        in_channels (int): Number of input image channels. Default: 3.
         downsample_dw_channels1 (int): Number of output channels after
             the first conv layer in Learning-To-Downsample (LTD) module.
-
+            Default: 32.
         downsample_dw_channels2 (int): Number of output channels
             after the second conv layer in LTD.
-
+            Default: 48.
         global_in_channels (int): Number of input channels of
             Global Feature Extractor(GFE).
             Equal to number of output channels of LTD.
-
+            Default: 64.
         global_block_channels (tuple): Tuple of integers that describe
             the output channels for each of the MobileNet-v2 bottleneck
             residual blocks in GFE.
-
+            Default: (64, 96, 128).
         global_out_channels (int): Number of output channels of GFE.
-
+            Default: 128.
         higher_in_channels (int): Number of input channels of the higher
             resolution branch in FFM.
             Equal to global_in_channels.
-
+            Default: 64.
         lower_in_channels (int): Number of input channels of the lower
             resolution branch in FFM.
             Equal to global_out_channels.
-
+            Default: 128.
         fusion_out_channels (int): Number of output channels of FFM.
-
+            Default: 128.
         scale_factor (int): The upsampling factor of the higher resolution
             branch in FFM.
             Equal to the downsampling factor in GFE.
-
+            Default: 4.
         out_indices (tuple): Tuple of indices of list
             [higher_res_features, lower_res_features, fusion_output].
             Often set to (0,1,2) to enable aux. heads.
-
+            Default: (0, 1, 2).
-        conv_cfg (dict|None): Config of conv layers.
-        norm_cfg (dict|None): Config of norm layers.
-        act_cfg (dict): Config of activation layers.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
         align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
     """
 
     def __init__(self,
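Taken together with the shape assertions in the updated backbone test further down, the documented defaults give this usage picture (a sketch, not part of the diff; the import path is assumed):

    import torch
    from mmseg.models.backbones import FastSCNN  # import path assumed

    model = FastSCNN()  # all defaults from the docstring above
    feats = model(torch.randn(4, 3, 1024, 2048))
    # out_indices=(0, 1, 2) yields, per the test below:
    #   feats[0]: higher-res features, (4, 64, 128, 256)  -> 1/8 scale, 64 ch
    #   feats[1]: lower-res features,  (4, 128, 32, 64)   -> 1/32 scale, 128 ch
    #   feats[2]: FFM fusion output,   (4, 128, 128, 256) -> 1/8 scale, 128 ch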
@@ -11,21 +11,16 @@ class DepthwiseSeparableFCNHead(FCNHead):
     This head is implemented according to Fast-SCNN.
 
     Args:
         in_channels(int): Number of output channels of FFM.
-
         channels(int): Number of middle-stage channels in the decode head.
-
         concat_input(bool): Whether to concatenate original decode input into
             the result of several consecutive convolution layers.
-
+            Default: True.
         num_classes(int): Used to determine the dimension of
             final prediction tensor.
-
         in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
-
         norm_cfg (dict|None): Config of norm layers.
-
         align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
         loss_decode(dict): Config of loss type and some
             relevant additional options.
     """
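A hedged instantiation sketch that ties this head to the backbone docstrings above. The argument values are illustrative, not read from any config in this commit, and the import path is assumed:

    from mmseg.models.decode_heads import DepthwiseSeparableFCNHead  # path assumed

    head = DepthwiseSeparableFCNHead(
        in_channels=128,           # FFM's fusion_out_channels
        channels=128,              # middle-stage channels
        concat_input=False,
        num_classes=19,            # e.g. Cityscapes
        in_index=-1,               # picks fusion_output from out_indices
        norm_cfg=dict(type='BN'),
        align_corners=False)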
@@ -9,10 +9,12 @@ class InvertedResidual(nn.Module):
         oup (int): output channels.
         stride (int): downsampling factor.
         expand_ratio (int): 1 or 2.
-        dilation (int): Dilated conv. Default: 1
-        conv_cfg (dict|None): Config of conv layers.
-        norm_cfg (dict|None): Config of norm layers.
-        act_cfg (dict): Config of activation layers.
+        dilation (int): Dilated conv. Default: 1.
+        conv_cfg (dict | None): Config of conv layers. Default: None.
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN').
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU6').
     """
 
     def __init__(self,
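For orientation, a generic MobileNetV2-style inverted residual, which is the structure this docstring describes: a pointwise expansion, a depthwise conv that carries the stride, and a linear pointwise projection, with a residual only when stride is 1 and input and output channels match. This is a sketch of the pattern (norm and activation layers omitted for brevity), not mmseg's implementation; see the new unit test below for the real module's behavior:

    import torch
    import torch.nn as nn

    class InvertedResidualSketch(nn.Module):
        def __init__(self, inp, oup, stride, expand_ratio):
            super().__init__()
            assert stride in (1, 2)  # mirrors the AssertionError the new test expects
            hidden = inp * expand_ratio
            self.use_res_connect = stride == 1 and inp == oup
            layers = []
            if expand_ratio != 1:
                layers.append(nn.Conv2d(inp, hidden, 1, bias=False))  # 1x1 expand
            layers += [
                nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden,
                          bias=False),                  # 3x3 depthwise, carries stride
                nn.Conv2d(hidden, oup, 1, bias=False),  # 1x1 linear projection
            ]
            self.conv = nn.Sequential(*layers)

        def forward(self, x):
            out = self.conv(x)
            return x + out if self.use_res_connect else out

    x = torch.rand(1, 32, 64, 64)
    assert InvertedResidualSketch(32, 32, 1, 4)(x).shape == (1, 32, 64, 64)
    assert InvertedResidualSketch(32, 32, 2, 4)(x).shape == (1, 32, 32, 32)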
@@ -41,12 +41,9 @@ def collect_env():
     for name, devids in devices.items():
         env_info['GPU ' + ','.join(devids)] = name
 
-    try:
-        gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
-        gcc = gcc.decode('utf-8').strip()
-        env_info['GCC'] = gcc
-    except subprocess.CalledProcessError:
-        env_info['GCC'] = 'n/a'
+    gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+    gcc = gcc.decode('utf-8').strip()
+    env_info['GCC'] = gcc
 
     env_info['PyTorch'] = torch.__version__
     env_info['PyTorch compiling details'] = get_build_config()
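Note the direction of this change: the try/except guard around the gcc probe is removed, so the bare check_output call is what ships. On a machine without gcc the shell pipeline exits non-zero and check_output raises subprocess.CalledProcessError; the removed guard, shown standalone below, is what previously mapped that failure to 'n/a':

    import subprocess

    def gcc_version():
        # The pattern removed above: probe gcc, fall back to 'n/a' when
        # the shell pipeline fails (e.g. gcc not installed).
        try:
            gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
            return gcc.decode('utf-8').strip()
        except subprocess.CalledProcessError:
            return 'n/a'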
@@ -675,14 +675,14 @@ def test_fastscnn_backbone():
     model = FastSCNN()
     model.init_weights()
     model.train()
-    num_batch_picts = 4
-    imgs = torch.randn(num_batch_picts, 3, 1024, 2048)
+    batch_size = 4
+    imgs = torch.randn(batch_size, 3, 1024, 2048)
     feat = model(imgs)
 
     assert len(feat) == 3
     # higher-res
-    assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256])
+    assert feat[0].shape == torch.Size([batch_size, 64, 128, 256])
     # lower-res
-    assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64])
+    assert feat[1].shape == torch.Size([batch_size, 128, 32, 64])
     # FFM output
-    assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256])
+    assert feat[2].shape == torch.Size([batch_size, 128, 128, 256])
tests/test_ops/test_inverted_residual_module.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from mmseg.ops import InvertedResidual
+
+
+def test_inv_residual():
+    with pytest.raises(AssertionError):
+        # test stride assertion.
+        InvertedResidual(32, 32, 3, 4)
+
+    # test default config with res connection.
+    # set expand_ratio = 4, stride = 1 and inp=oup.
+    inv_module = InvertedResidual(32, 32, 1, 4)
+    assert inv_module.use_res_connect
+    assert inv_module.conv[0].kernel_size == 3
+    assert inv_module.conv[0].padding == 1
+    x = torch.rand(1, 32, 64, 64)
+    output = inv_module(x)
+    assert output.shape == (1, 32, 64, 64)
+
+    # test inv_residual module without res connection.
+    # set expand_ratio = 4, stride = 2.
+    inv_module = InvertedResidual(32, 32, 2, 4)
+    assert not inv_module.use_res_connect
+    assert inv_module.conv[0].kernel_size == 1
+    x = torch.rand(1, 32, 64, 64)
+    output = inv_module(x)
+    assert output.shape == (1, 32, 32, 32)
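The new test runs standalone with "pytest tests/test_ops/test_inverted_residual_module.py"; it exercises the stride assertion, the residual and non-residual paths, and the output shapes of InvertedResidual.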