# mmsegmentation/tests/test_models/test_necks/test_jpu.py
#
# 41 lines, 1.3 KiB, Python
# (file-viewer residue: "Raw Normal View History")

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.necks import JPU
def test_fastfcn_neck():
    """Check the output count and shapes of the ``JPU`` neck.

    Three scenarios are exercised:

    1. A default-constructed ``JPU`` fed three feature maps returns three
       feature maps with the shapes asserted below.
    2. Constructing ``JPU`` with ``end_level=5`` for three ``in_channels``
       raises ``AssertionError``.
    3. With ``start_level=1`` and ``end_level=-1`` the neck returns two
       feature maps with the shapes asserted below.
    """
    # Test FastFCN Standard Forward
    model = JPU()
    model.init_weights()
    model.train()
    batch_size = 1
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = [
        torch.randn(batch_size, 512, 64, 128),
        torch.randn(batch_size, 1024, 32, 64),
        torch.randn(batch_size, 2048, 16, 32),
    ]
    feat = model(inputs)

    assert len(feat) == 3
    assert feat[0].shape == torch.Size([batch_size, 512, 64, 128])
    assert feat[1].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feat[2].shape == torch.Size([batch_size, 2048, 64, 128])

    with pytest.raises(AssertionError):
        # FastFCN input and in_channels constraints.
        # Only the constructor call belongs inside the context manager, so
        # the assertions that follow are not skipped once it raises.
        JPU(in_channels=(256, 512, 1024), start_level=0, end_level=5)

    # Test not default start_level
    model = JPU(in_channels=(512, 1024, 2048), start_level=1, end_level=-1)
    inputs = [
        torch.randn(batch_size, 512, 64, 128),
        torch.randn(batch_size, 1024, 32, 64),
        torch.randn(batch_size, 2048, 16, 32),
    ]
    feat = model(inputs)

    assert len(feat) == 2
    assert feat[0].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64])