47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.necks import JPU
|
|
|
|
|
|
def _make_inputs(batch_size):
    """Build a 3-level feature pyramid matching in_channels=(64, 128, 256).

    Each level halves the spatial size of the previous one:
    (64, 64x128) -> (128, 32x64) -> (256, 16x32).
    """
    return [
        torch.randn(batch_size, 64, 64, 128),
        torch.randn(batch_size, 128, 32, 64),
        torch.randn(batch_size, 256, 16, 32),
    ]


def test_fastfcn_neck():
    """Check JPU output count/shapes and constructor argument validation.

    Covers the standard forward with ``start_level=0``, rejection of an
    out-of-range ``end_level``, and a forward with ``start_level=1``.
    """
    batch_size = 1

    # Test FastFCN standard forward.
    model = JPU(
        in_channels=(64, 128, 256),
        mid_channels=64,
        start_level=0,
        end_level=-1,
        dilations=(1, 2, 4, 8),
    )
    model.init_weights()
    model.train()
    feats = model(_make_inputs(batch_size))

    assert len(feats) == 3
    # The first two outputs keep the shapes of the corresponding inputs.
    assert feats[0].shape == torch.Size([batch_size, 64, 64, 128])
    assert feats[1].shape == torch.Size([batch_size, 128, 32, 64])
    # The fused output is at the start_level resolution; 256 channels is
    # presumably mid_channels * len(dilations) = 64 * 4.
    assert feats[2].shape == torch.Size([batch_size, 256, 64, 128])

    with pytest.raises(AssertionError):
        # end_level=5 exceeds the three configured input levels; presumably
        # this is what the constructor's assertion rejects.
        JPU(in_channels=(256, 64, 128), start_level=0, end_level=5)

    # Test a non-default start_level: level 0 is skipped entirely.
    model = JPU(in_channels=(64, 128, 256), start_level=1, end_level=-1)
    feats = model(_make_inputs(batch_size))

    assert len(feats) == 2
    assert feats[0].shape == torch.Size([batch_size, 128, 32, 64])
    # 2048 channels presumably comes from the default mid_channels (512)
    # times the default four dilations; spatial size follows level 1.
    assert feats[1].shape == torch.Size([batch_size, 2048, 32, 64])
|