# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import ERFNet
from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d,
                                           UpsamplerBlock)


def _erfnet_kwargs(**overrides):
    """Return keyword arguments for a valid ERFNet, with ``overrides`` applied.

    The defaults are the standard configuration used by the forward test.
    Each negative test overrides only the field(s) whose consistency it
    probes, so the remaining arguments stay valid and the failure is
    attributable to that one field.
    """
    kwargs = dict(
        in_channels=3,
        enc_downsample_channels=(16, 64, 128),
        enc_stage_non_bottlenecks=(5, 8),
        enc_non_bottleneck_dilations=(2, 4, 8, 16),
        enc_non_bottleneck_channels=(64, 128),
        dec_upsample_channels=(64, 16),
        dec_stages_non_bottleneck=(2, 2),
        dec_non_bottleneck_channels=(64, 16),
        dropout_ratio=0.1,
    )
    kwargs.update(overrides)
    return kwargs


def test_erfnet_backbone():
    """Check ERFNet forward output shapes and its config consistency checks."""
    # Test ERFNet Standard Forward.
    model = ERFNet(**_erfnet_kwargs())
    model.init_weights()
    model.train()
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 256, 512)
    output = model(imgs)

    # Output for the segmentation head: 16 channels (the last decoder
    # stage) at half the input resolution for this 256x512 input.
    assert output[0].shape == torch.Size([batch_size, 16, 128, 256])

    # Test input with a rare shape (odd height and width).
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 527, 279)
    output = model(imgs)
    assert len(output[0]) == batch_size

    with pytest.raises(AssertionError):
        # Number of encoder downsample block and decoder upsample block.
        ERFNet(**_erfnet_kwargs(dec_upsample_channels=(128, 64, 16)))
    with pytest.raises(AssertionError):
        # Number of encoder downsample block and encoder Non-bottleneck block.
        ERFNet(**_erfnet_kwargs(enc_stage_non_bottlenecks=(5, 8, 10)))
    with pytest.raises(AssertionError):
        # Number of encoder downsample block and
        # channels of encoder Non-bottleneck block.
        ERFNet(**_erfnet_kwargs(enc_non_bottleneck_channels=(64, 128, 256)))
    with pytest.raises(AssertionError):
        # Number of Non-bottleneck blocks in the last encoder stage versus
        # the number of dilations: the standard config uses 8 blocks for 4
        # dilations. Only the last entry changes (8 -> 7) so all length
        # checks still pass. The previous `(5, 8, 3)` value tripped the
        # same length mismatch as the `(5, 8, 10)` case above.
        # NOTE(review): relies on ERFNet asserting that this count is a
        # multiple of len(enc_non_bottleneck_dilations) — confirm in
        # erfnet.py.
        ERFNet(**_erfnet_kwargs(enc_stage_non_bottlenecks=(5, 7)))
    with pytest.raises(AssertionError):
        # Number of decoder upsample block and decoder Non-bottleneck block.
        ERFNet(**_erfnet_kwargs(dec_stages_non_bottleneck=(2, 2, 3)))
    with pytest.raises(AssertionError):
        # Number of decoder Non-bottleneck block and number of its channels.
        ERFNet(**_erfnet_kwargs(dec_non_bottleneck_channels=(64, 16, 8)))


def test_erfnet_downsampler_block():
    """Check the sub-module configuration of ``DownsamplerBlock(16, 64)``."""
    block = DownsamplerBlock(16, 64)
    conv, bn, pool = block.conv, block.bn, block.pool
    # The conv emits 64 - 16 = 48 channels while the norm layer covers all
    # 64; presumably the other 16 come from the pooled input (confirm in
    # erfnet.py).
    assert (conv.in_channels, conv.out_channels) == (16, 48)
    assert len(bn.weight) == 64
    # 2x2 window with stride 2.
    assert (pool.kernel_size, pool.stride) == (2, 2)


def test_erfnet_non_bottleneck_1d():
    """Check the layer stack of ``NonBottleneck1d(16, 0, 1)``."""
    layers = NonBottleneck1d(16, 0, 1).convs_layers
    # The layers at these positions are expected to map 16 -> 16 channels.
    for idx in (0, 2, 5, 7):
        assert layers[idx].in_channels == 16
        assert layers[idx].out_channels == 16
    # The layer at index 9 exposes ``p``; it is expected to equal the 0
    # passed as the second constructor argument (presumably the drop rate).
    assert layers[9].p == 0


def test_erfnet_upsampler_block():
    """Check the channel wiring of ``UpsamplerBlock(64, 16)``."""
    block = UpsamplerBlock(64, 16)
    conv = block.conv
    # The conv maps the 64 input channels to the 16 output channels, and the
    # norm layer is sized for the output.
    assert (conv.in_channels, conv.out_channels) == (64, 16)
    assert len(block.bn.weight) == 16