add unit test for inverted_residual: debug 0

pull/58/head
johnzja 2020-08-11 20:14:14 +08:00
parent ae85850d30
commit 0d0641b5c1
1 changed files with 13 additions and 3 deletions

View File

@ -14,8 +14,12 @@ def test_inv_residual():
# set expand_ratio = 4, stride = 1 and inp=oup.
inv_module = InvertedResidual(32, 32, 1, 4)
assert inv_module.use_res_connect
assert inv_module.conv[0].kernel_size == 3
assert inv_module.conv[0].padding == 1
assert inv_module.conv[0].kernel_size == (1, 1)
assert inv_module.conv[0].padding == 0
assert inv_module.conv[1].kernel_size == (3, 3)
assert inv_module.conv[1].padding == 1
assert not inv_module.conv[0].with_norm
assert not inv_module.conv[1].with_norm
x = torch.rand(1, 32, 64, 64)
output = inv_module(x)
assert output.shape == (1, 32, 64, 64)
@ -24,11 +28,17 @@ def test_inv_residual():
# set expand_ratio = 4, stride = 2.
inv_module = InvertedResidual(32, 32, 2, 4)
assert not inv_module.use_res_connect
assert inv_module.conv[0].kernel_size == 1
assert inv_module.conv[0].kernel_size == (1, 1)
x = torch.rand(1, 32, 64, 64)
output = inv_module(x)
assert output.shape == (1, 32, 32, 32)
# test expand_ratio == 1
inv_module = InvertedResidual(32, 32, 1, 1)
assert inv_module.conv[0].kernel_size == (3, 3)
x = torch.rand(1, 32, 64, 64)
output = inv_module(x)
assert output.shape == (1, 32, 64, 64)