mmsegmentation/tests/test_models/test_necks/test_jpu.py

47 lines
1.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.necks import JPU
def test_fastfcn_neck():
    """Check the output shapes of the JPU neck and its argument validation.

    Three scenarios are exercised:

    1. A forward pass using all three input levels (``start_level=0``).
    2. Construction with ``end_level=5`` for three ``in_channels`` entries,
       which is expected to raise ``AssertionError``.
    3. A forward pass with ``start_level=1`` and the default
       ``mid_channels``.
    """
    batch_size = 1
    # (channels, height, width) of each input feature map; the channel
    # counts line up with ``in_channels=(64, 128, 256)`` used below.
    feature_shapes = ((64, 64, 128), (128, 32, 64), (256, 16, 32))

    def make_inputs():
        # Build a fresh list of random feature maps for every forward pass
        # so that the two scenarios never share tensors.
        return [torch.randn(batch_size, *shape) for shape in feature_shapes]

    # Test FastFCN Standard Forward
    model = JPU(
        in_channels=(64, 128, 256),
        mid_channels=64,
        start_level=0,
        end_level=-1,
        dilations=(1, 2, 4, 8),
    )
    model.init_weights()
    model.train()
    inputs = make_inputs()
    feat = model(inputs)

    assert len(feat) == 3
    # The first two outputs keep the shapes of the first two inputs.
    assert feat[0].shape == torch.Size([batch_size, 64, 64, 128])
    assert feat[1].shape == torch.Size([batch_size, 128, 32, 64])
    # The last output has the spatial size of the first input.
    # NOTE(review): 256 channels presumably equals
    # len(dilations) * mid_channels (4 * 64) -- confirm against JPU.
    assert feat[2].shape == torch.Size([batch_size, 256, 64, 128])

    # Only the constructor call belongs inside the ``raises`` context:
    # anything placed after it in the same suite would never execute.
    with pytest.raises(AssertionError):
        # FastFCN input and in_channels constraints.
        JPU(in_channels=(256, 64, 128), start_level=0, end_level=5)

    # Test not default start_level
    model = JPU(in_channels=(64, 128, 256), start_level=1, end_level=-1)
    inputs = make_inputs()
    feat = model(inputs)

    assert len(feat) == 2
    assert feat[0].shape == torch.Size([batch_size, 128, 32, 64])
    # NOTE(review): 2048 channels presumably comes from the default
    # ``mid_channels`` and ``dilations`` of JPU -- confirm against JPU.
    assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64])