import pytest
import torch

from mmseg.models.backbones import FastSCNN


def test_fastscnn_backbone():
    """Exercise FastSCNN construction checks and forward output shapes.

    First confirms that an inconsistent channel configuration is rejected
    with an ``AssertionError``; then runs a default-configured model in
    training mode on a random batch and checks the shape of each of the
    three returned feature maps.
    """
    # Fast-SCNN channel constraints: this configuration must be refused.
    # NOTE(review): presumably rejected because global_out_channels (127)
    # does not match lower_in_channels (128) -- confirm against FastSCNN.
    invalid_channel_cfg = dict(
        global_out_channels=127,
        higher_in_channels=64,
        lower_in_channels=128,
    )
    with pytest.raises(AssertionError):
        FastSCNN(3, (32, 48), 64, (64, 96, 128), (2, 2, 1),
                 **invalid_channel_cfg)

    # Standard forward pass with the default configuration.
    backbone = FastSCNN()
    backbone.init_weights()
    backbone.train()
    num_imgs = 4
    outputs = backbone(torch.randn(num_imgs, 3, 512, 1024))

    # Expected (N, C, H, W) of each returned feature map, in output order.
    expected_shapes = (
        (num_imgs, 64, 64, 128),   # higher-res
        (num_imgs, 128, 16, 32),   # lower-res
        (num_imgs, 128, 64, 128),  # FFM output
    )
    assert len(outputs) == len(expected_shapes)
    for output, shape in zip(outputs, expected_shapes):
        assert output.shape == torch.Size(shape)