MengzhangLI f910caf229
[Feature] Support FastFCN (#885)
* FastFCN first commit

* Fixing lint error

* use for loop on JPU

* Use For Loop

* Refactor FastFCN

* FastFCN

* temp

* Uploading models & logs (4x4)

* Fixing typos

* fix typos

* rename config

* change README.md

* use _delete_=True

* change configs

* change start_level to 0

* jpu

* add unittest for start_level!=0
2021-10-01 02:41:24 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.necks import JPU


def test_fastfcn_neck():
    # Test FastFCN standard forward with the default JPU settings.
    model = JPU()
    model.init_weights()
    model.train()
    batch_size = 1
    input = [
        torch.randn(batch_size, 512, 64, 128),
        torch.randn(batch_size, 1024, 32, 64),
        torch.randn(batch_size, 2048, 16, 32)
    ]
    feat = model(input)

    assert len(feat) == 3
    assert feat[0].shape == torch.Size([batch_size, 512, 64, 128])
    assert feat[1].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feat[2].shape == torch.Size([batch_size, 2048, 64, 128])

    with pytest.raises(AssertionError):
        # FastFCN input and in_channels constraints.
        JPU(in_channels=(256, 512, 1024), start_level=0, end_level=5)

    # Test non-default start_level.
    model = JPU(in_channels=(512, 1024, 2048), start_level=1, end_level=-1)
    input = [
        torch.randn(batch_size, 512, 64, 128),
        torch.randn(batch_size, 1024, 32, 64),
        torch.randn(batch_size, 2048, 16, 32)
    ]
    feat = model(input)
    assert len(feat) == 2
    assert feat[0].shape == torch.Size([batch_size, 1024, 32, 64])
    assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64])
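For reference, the JPU module exercised by this test is intended to be plugged in as the `neck` of a segmentor. The fragment below is a minimal sketch of how that wiring could look in a FastFCN-style config; only the JPU arguments mirror the values tested above, while the backbone and decode_head entries (and their channel numbers) are illustrative placeholders rather than values taken from this commit.

# Hypothetical config fragment: where the JPU neck sits in a FastFCN-style model.
# Only the neck arguments mirror the test above; backbone/decode_head are placeholders.
model = dict(
    type='EncoderDecoder',
    backbone=dict(type='ResNetV1c', depth=50),  # placeholder backbone
    neck=dict(
        type='JPU',
        in_channels=(512, 1024, 2048),  # channels of the three inputs fed in the test
        start_level=0,
        end_level=-1),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,  # matches the channel count of the last JPU output above
        channels=512,
        num_classes=19))  # placeholder head

When such a neck is added to a config that inherits from an existing model, the `_delete_=True` key mentioned in the commit log tells the MMCV config system to replace the inherited dict outright instead of merging new keys into it.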