mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* add icnet backbone * add icnet head * add icnet configs * nclass -> num_classes * Support ICNet * ICNet * ICNet * Add ICNeck * Add ICNeck * Add ICNeck * Add ICNeck * Adding unittest * Uploading models & logs * Uploading models & logs * add comment * smaller test_swin.py * try to delete test_swin.py * delete test_unet.py * delete test_unet.py * temp * smaller test_unet.py Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.necks import ICNeck
|
|
from mmseg.models.necks.ic_neck import CascadeFeatureFusion
|
|
from ..test_heads.utils import _conv_has_norm, to_cuda
|
|
|
|
|
|
def test_ic_neck():
    """Check ICNeck's norm layers and the shapes of all three outputs."""
    # test with norm_cfg: only the presence of SyncBN-normalized convs is
    # checked here; no forward pass is run with this neck (presumably because
    # SyncBN needs a distributed setup -- TODO confirm).
    neck = ICNeck(
        in_channels=(64, 256, 256),
        out_channels=128,
        norm_cfg=dict(type='SyncBN'),
        align_corners=False)
    assert _conv_has_norm(neck, sync_bn=True)

    # One feature map per entry of `in_channels` (64, 256, 256 channels).
    inputs = [
        torch.randn(1, 64, 128, 256),
        torch.randn(1, 256, 65, 129),
        torch.randn(1, 256, 32, 64)
    ]
    # Rebuild the neck with plain BN so the forward pass below can run.
    neck = ICNeck(
        in_channels=(64, 256, 256),
        out_channels=128,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False)
    if torch.cuda.is_available():
        neck, inputs = to_cuda(neck, inputs)

    outputs = neck(inputs)
    assert len(outputs) == 3
    assert outputs[0].shape == (1, 128, 65, 129)
    assert outputs[1].shape == (1, 128, 128, 256)
    # The third output must be checked explicitly; asserting outputs[1] a
    # second time would leave it untested.
    assert outputs[2].shape == (1, 128, 128, 256)
|
|
|
|
|
|
def test_ic_neck_cascade_feature_fusion():
    """Verify the in/out channel counts of both CascadeFeatureFusion convs."""
    in_ch, out_ch = 256, 128
    cff = CascadeFeatureFusion(in_ch, in_ch, out_ch)

    # Each conv maps the 256-channel input to the 128-channel output.
    expected_channels = {
        'conv_low': (in_ch, out_ch),
        'conv_high': (in_ch, out_ch),
    }
    for conv_name, (want_in, want_out) in expected_channels.items():
        conv = getattr(cff, conv_name)
        assert conv.in_channels == want_in
        assert conv.out_channels == want_out
|
|
|
|
|
|
def test_ic_neck_input_channels():
    """Building ICNeck with a four-entry ``in_channels`` raises AssertionError."""
    invalid_kwargs = dict(
        in_channels=(64, 256, 256, 256),
        out_channels=128,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False)

    # ICNet Neck input channel constraints.
    with pytest.raises(AssertionError):
        ICNeck(**invalid_kwargs)
|