41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmseg.models.necks import JPU
|
||
|
|
||
|
|
||
|
def test_fastfcn_neck():
    """Check JPU output shapes and its constructor argument validation.

    Three cases are exercised:

    1. A default-constructed ``JPU`` in train mode returns three feature
       maps with the expected shapes.
    2. Constructing ``JPU`` with ``end_level=5`` for three ``in_channels``
       entries raises ``AssertionError``.
    3. A ``JPU`` built with ``start_level=1`` returns two feature maps.
    """
    batch_size = 1
    # (channels, height, width) of each feature map fed to the neck.
    level_shapes = ((512, 64, 128), (1024, 32, 64), (2048, 16, 32))

    def make_inputs():
        # Draw fresh random tensors for every forward pass so the two
        # forward cases below do not share input data.
        return [torch.randn(batch_size, *shape) for shape in level_shapes]

    # Test FastFCN Standard Forward
    model = JPU()
    model.init_weights()
    model.train()
    feat = model(make_inputs())

    assert len(feat) == 3
    assert feat[0].shape == torch.Size([batch_size, 512, 64, 128])
    assert feat[1].shape == torch.Size([batch_size, 1024, 32, 64])
    # The last output keeps 2048 channels but has the 64x128 spatial size
    # of the first input level.
    assert feat[2].shape == torch.Size([batch_size, 2048, 64, 128])

    with pytest.raises(AssertionError):
        # FastFCN input and in_channels constraints.
        # NOTE(review): presumably end_level=5 exceeding the three
        # in_channels entries is what trips the assertion -- confirm
        # against the JPU constructor.
        JPU(in_channels=(256, 512, 1024), start_level=0, end_level=5)

    # Test not default start_level
    model = JPU(in_channels=(512, 1024, 2048), start_level=1, end_level=-1)
    feat = model(make_inputs())

    assert len(feat) == 2
    assert feat[0].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64])
|