mmsegmentation/tests/test_models/test_necks/test_multilevel_neck.py
谢昕辰 07cc26ae5a add upsample neck (#512)
* init

* upsample v1.0

* fix errors

* change to in_channels list

* add unittest, docstring, norm/act config and rename

Co-authored-by: xiexinch <test767803@foxmail.com>
2021-04-24 21:22:09 -07:00

29 lines
947 B
Python

import torch
from mmseg.models import MultiLevelNeck
def test_multilevel_neck():
    """Check the output shapes produced by ``MultiLevelNeck``.

    Every input feature map is 14x14. The neck is expected to project each
    output to 256 channels and to emit four maps at 7x7, 14x14, 28x28 and
    56x56. This is checked both with one input map per output level and
    with a single input map.
    """
    out_channels = 256
    # Spatial sizes of the four expected outputs for a 14x14 input. These
    # presumably come from the neck's default rescale factors
    # (0.5, 1, 2, 4) -- NOTE(review): confirm against MultiLevelNeck.
    expected_sizes = [7, 14, 28, 56]

    def _check_outputs(outputs):
        # Exactly one output per expected level, each with the projected
        # channel count and the expected square spatial size.
        assert len(outputs) == len(expected_sizes)
        for out, size in zip(outputs, expected_sizes):
            assert out.shape == torch.Size([1, out_channels, size, size])

    # Test multi feature maps: one input per output level.
    in_channels = [256, 512, 1024, 2048]
    inputs = [torch.randn(1, c, 14, 14) for c in in_channels]
    neck = MultiLevelNeck(in_channels, out_channels)
    _check_outputs(neck(inputs))

    # Test one feature map: a single input must still yield all four levels.
    in_channels = [768]
    inputs = [torch.randn(1, 768, 14, 14)]
    neck = MultiLevelNeck(in_channels, out_channels)
    _check_outputs(neck(inputs))