121 lines
4.7 KiB
Python
121 lines
4.7 KiB
Python
import mmcv
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.utils import InvertedResidual, InvertedResidualV3
|
|
|
|
|
|
def test_inv_residual():
|
|
with pytest.raises(AssertionError):
|
|
# test stride assertion.
|
|
InvertedResidual(32, 32, 3, 4)
|
|
|
|
# test default config with res connection.
|
|
# set expand_ratio = 4, stride = 1 and inp=oup.
|
|
inv_module = InvertedResidual(32, 32, 1, 4)
|
|
assert inv_module.use_res_connect
|
|
assert inv_module.conv[0].kernel_size == (1, 1)
|
|
assert inv_module.conv[0].padding == 0
|
|
assert inv_module.conv[1].kernel_size == (3, 3)
|
|
assert inv_module.conv[1].padding == 1
|
|
assert inv_module.conv[0].with_norm
|
|
assert inv_module.conv[1].with_norm
|
|
x = torch.rand(1, 32, 64, 64)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 32, 64, 64)
|
|
|
|
# test inv_residual module without res connection.
|
|
# set expand_ratio = 4, stride = 2.
|
|
inv_module = InvertedResidual(32, 32, 2, 4)
|
|
assert not inv_module.use_res_connect
|
|
assert inv_module.conv[0].kernel_size == (1, 1)
|
|
x = torch.rand(1, 32, 64, 64)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 32, 32, 32)
|
|
|
|
# test expand_ratio == 1
|
|
inv_module = InvertedResidual(32, 32, 1, 1)
|
|
assert inv_module.conv[0].kernel_size == (3, 3)
|
|
x = torch.rand(1, 32, 64, 64)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 32, 64, 64)
|
|
|
|
# test with checkpoint forward
|
|
inv_module = InvertedResidual(32, 32, 1, 1, with_cp=True)
|
|
assert inv_module.with_cp
|
|
x = torch.rand(1, 32, 64, 64, requires_grad=True)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 32, 64, 64)
|
|
|
|
|
|
def test_inv_residualv3():
|
|
with pytest.raises(AssertionError):
|
|
# test stride assertion.
|
|
InvertedResidualV3(32, 32, 16, stride=3)
|
|
|
|
with pytest.raises(AssertionError):
|
|
# test assertion.
|
|
InvertedResidualV3(32, 32, 16, with_expand_conv=False)
|
|
|
|
# test with se_cfg=None, with_expand_conv=False
|
|
inv_module = InvertedResidualV3(32, 32, 32, with_expand_conv=False)
|
|
|
|
assert inv_module.with_res_shortcut is True
|
|
assert inv_module.with_se is False
|
|
assert inv_module.with_expand_conv is False
|
|
assert not hasattr(inv_module, 'expand_conv')
|
|
assert isinstance(inv_module.depthwise_conv.conv, torch.nn.Conv2d)
|
|
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
|
|
assert inv_module.depthwise_conv.conv.stride == (1, 1)
|
|
assert inv_module.depthwise_conv.conv.padding == (1, 1)
|
|
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
|
assert isinstance(inv_module.depthwise_conv.activate, torch.nn.ReLU)
|
|
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
|
assert inv_module.linear_conv.conv.stride == (1, 1)
|
|
assert inv_module.linear_conv.conv.padding == (0, 0)
|
|
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
|
|
|
x = torch.rand(1, 32, 64, 64)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 32, 64, 64)
|
|
|
|
# test with se_cfg and with_expand_conv
|
|
se_cfg = dict(
|
|
channels=16,
|
|
ratio=4,
|
|
act_cfg=(dict(type='ReLU'),
|
|
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
|
|
act_cfg = dict(type='HSwish')
|
|
inv_module = InvertedResidualV3(
|
|
32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg)
|
|
assert inv_module.with_res_shortcut is False
|
|
assert inv_module.with_se is True
|
|
assert inv_module.with_expand_conv is True
|
|
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
|
|
assert inv_module.expand_conv.conv.stride == (1, 1)
|
|
assert inv_module.expand_conv.conv.padding == (0, 0)
|
|
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
|
|
|
assert isinstance(inv_module.depthwise_conv.conv,
|
|
mmcv.cnn.bricks.Conv2dAdaptivePadding)
|
|
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
|
|
assert inv_module.depthwise_conv.conv.stride == (2, 2)
|
|
assert inv_module.depthwise_conv.conv.padding == (0, 0)
|
|
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
|
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
|
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
|
assert inv_module.linear_conv.conv.stride == (1, 1)
|
|
assert inv_module.linear_conv.conv.padding == (0, 0)
|
|
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
|
x = torch.rand(1, 32, 64, 64)
|
|
output = inv_module(x)
|
|
assert output.shape == (1, 40, 32, 32)
|
|
|
|
# test with checkpoint forward
|
|
inv_module = InvertedResidualV3(
|
|
32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True)
|
|
assert inv_module.with_cp
|
|
x = torch.randn(2, 32, 64, 64, requires_grad=True)
|
|
output = inv_module(x)
|
|
assert output.shape == (2, 40, 32, 32)
|