diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 9ba33d283..282550179 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -675,13 +675,14 @@ def test_fastscnn_backbone(): model = FastSCNN() model.init_weights() model.train() - imgs = torch.randn(1, 3, 1024, 2048) + num_batch_picts = 4 + imgs = torch.randn(num_batch_picts, 3, 1024, 2048) feat = model(imgs) assert len(feat) == 3 - assert feat[0].shape == torch.Size([1, 64, 128, 256]) # higher-res - assert feat[1].shape == torch.Size([1, 128, 32, 64]) # lower-res - assert feat[2].shape == torch.Size([1, 128, 128, 256]) # FFM output + assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) # higher-res + assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) # lower-res + assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256]) # FFM output