mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* FastFCN first commit * FastFCN first commit * Fixing lint error * Fixing lint error * use for loop on JPU * Use For Loop * Refactor FastFCN * FastFCN * FastFCN * temp * Uploading models & logs (4x4) * Fixing typos * fix typos * rename config * change README.md * use _delete_=True * change configs * change start_level to 0 * change start_level to 0 * jpu * add unittest for start_level!=0
41 lines · 1.3 KiB · Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.necks import JPU
|
|
|
|
|
|
def _make_jpu_inputs(batch_size):
    """Build a fresh three-level feature pyramid used as JPU input.

    Channels double and spatial size halves from level to level:
    (512, 64x128) -> (1024, 32x64) -> (2048, 16x32).
    """
    return [
        torch.randn(batch_size, 512, 64, 128),
        torch.randn(batch_size, 1024, 32, 64),
        torch.randn(batch_size, 2048, 16, 32),
    ]


def test_fastfcn_neck():
    """Check JPU output count and shapes for default and custom levels."""
    batch_size = 1

    # Test FastFCN Standard Forward
    model = JPU()
    model.init_weights()
    model.train()
    # `inputs`/`feats` rather than `input` to avoid shadowing the builtin.
    feats = model(_make_jpu_inputs(batch_size))

    # Default config: the leading outputs keep their input shapes, and the
    # last (JPU) output is expected to have 2048 channels at the spatial
    # size of the first input level.
    assert len(feats) == 3
    assert feats[0].shape == torch.Size([batch_size, 512, 64, 128])
    assert feats[1].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feats[2].shape == torch.Size([batch_size, 2048, 64, 128])

    # FastFCN input and in_channels constraints: this level configuration
    # (end_level=5 with only three in_channels) must be rejected when the
    # module is constructed.
    with pytest.raises(AssertionError):
        JPU(in_channels=(256, 512, 1024), start_level=0, end_level=5)

    # Test not default start_level: with start_level=1 one fewer output is
    # expected, and the JPU output follows level 1's spatial size.
    model = JPU(in_channels=(512, 1024, 2048), start_level=1, end_level=-1)
    feats = model(_make_jpu_inputs(batch_size))
    assert len(feats) == 2
    assert feats[0].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feats[1].shape == torch.Size([batch_size, 2048, 32, 64])