From 0d0641b5c198221612929aba9cfc74c62149a5a8 Mon Sep 17 00:00:00 2001 From: johnzja Date: Tue, 11 Aug 2020 20:14:14 +0800 Subject: [PATCH] add unit test for inverted_residual: debug 0 --- tests/test_ops/test_inverted_residual_module.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_ops/test_inverted_residual_module.py b/tests/test_ops/test_inverted_residual_module.py index a00f1510b..a49210349 100644 --- a/tests/test_ops/test_inverted_residual_module.py +++ b/tests/test_ops/test_inverted_residual_module.py @@ -14,8 +14,12 @@ def test_inv_residual(): # set expand_ratio = 4, stride = 1 and inp=oup. inv_module = InvertedResidual(32, 32, 1, 4) assert inv_module.use_res_connect - assert inv_module.conv[0].kernel_size == 3 - assert inv_module.conv[0].padding == 1 + assert inv_module.conv[0].kernel_size == (1, 1) + assert inv_module.conv[0].padding == 0 + assert inv_module.conv[1].kernel_size == (3, 3) + assert inv_module.conv[1].padding == 1 + assert not inv_module.conv[0].with_norm + assert not inv_module.conv[1].with_norm x = torch.rand(1, 32, 64, 64) output = inv_module(x) assert output.shape == (1, 32, 64, 64) @@ -24,11 +28,17 @@ def test_inv_residual(): # set expand_ratio = 4, stride = 2. inv_module = InvertedResidual(32, 32, 2, 4) assert not inv_module.use_res_connect - assert inv_module.conv[0].kernel_size == 1 + assert inv_module.conv[0].kernel_size == (1, 1) x = torch.rand(1, 32, 64, 64) output = inv_module(x) assert output.shape == (1, 32, 32, 32) + # test expand_ratio == 1 + inv_module = InvertedResidual(32, 32, 1, 1) + assert inv_module.conv[0].kernel_size == (3, 3) + x = torch.rand(1, 32, 64, 64) + output = inv_module(x) + assert output.shape == (1, 32, 64, 64)