import pytest
import torch

from mmseg.utils import InvertedResidual


def test_inv_residual():
    with pytest.raises(AssertionError):
        # test stride assertion.
        InvertedResidual(32, 32, 3, 4)

    # test default config with res connection.
    # set expand_ratio = 4, stride = 1 and inp=oup.
    inv_module = InvertedResidual(32, 32, 1, 4)
    assert inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == (1, 1)
    assert inv_module.conv[0].padding == 0
    assert inv_module.conv[1].kernel_size == (3, 3)
    assert inv_module.conv[1].padding == 1
    assert inv_module.conv[0].with_norm
    assert inv_module.conv[1].with_norm
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)

    # test inv_residual module without res connection.
    # set expand_ratio = 4, stride = 2.
    inv_module = InvertedResidual(32, 32, 2, 4)
    assert not inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == (1, 1)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 32, 32)

    # test expand_ratio == 1
    inv_module = InvertedResidual(32, 32, 1, 1)
    assert inv_module.conv[0].kernel_size == (3, 3)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)