147 lines
5.3 KiB
Python
147 lines
5.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.backbones import ERFNet
|
|
from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d,
|
|
UpsamplerBlock)
|
|
|
|
|
|
def test_erfnet_backbone():
    """Exercise ERFNet's forward pass and its constructor assertions.

    A reference configuration is built once; the forward checks use it
    unchanged, and every invalid-construction case overrides exactly one
    of its entries and expects an ``AssertionError``.
    """
    reference_cfg = dict(
        in_channels=3,
        enc_downsample_channels=(16, 64, 128),
        enc_stage_non_bottlenecks=(5, 8),
        enc_non_bottleneck_dilations=(2, 4, 8, 16),
        enc_non_bottleneck_channels=(64, 128),
        dec_upsample_channels=(64, 16),
        dec_stages_non_bottleneck=(2, 2),
        dec_non_bottleneck_channels=(64, 16),
        dropout_ratio=0.1,
    )

    # Test ERFNet Standard Forward.
    net = ERFNet(**reference_cfg)
    net.init_weights()
    net.train()
    num_imgs = 2
    feats = net(torch.randn(num_imgs, 3, 256, 512))

    # output for segment Head
    assert feats[0].shape == (num_imgs, 16, 128, 256)

    # Test input with rare shape
    feats = net(torch.randn(num_imgs, 3, 527, 279))
    assert len(feats[0]) == num_imgs

    # Each entry replaces a single keyword of the reference configuration
    # with a tuple of mismatching length; order matches the original cases.
    invalid_overrides = (
        # Number of encoder downsample block and decoder upsample block.
        dict(dec_upsample_channels=(128, 64, 16)),
        # Number of encoder downsample block and encoder Non-bottleneck
        # block.
        dict(enc_stage_non_bottlenecks=(5, 8, 10)),
        # Number of encoder downsample block and
        # channels of encoder Non-bottleneck block.
        dict(enc_non_bottleneck_channels=(64, 128, 256)),
        # Number of encoder Non-bottleneck block and number of its channels.
        dict(enc_stage_non_bottlenecks=(5, 8, 3)),
        # Number of decoder upsample block and decoder Non-bottleneck block.
        dict(dec_stages_non_bottleneck=(2, 2, 3)),
        # Number of decoder Non-bottleneck block and number of its channels.
        dict(dec_non_bottleneck_channels=(64, 16, 8)),
    )
    for override in invalid_overrides:
        with pytest.raises(AssertionError):
            ERFNet(**{**reference_cfg, **override})
|
|
|
|
|
|
def test_erfnet_downsampler_block():
    """Check the sub-layer attributes of ``DownsamplerBlock(16, 64)``."""
    block = DownsamplerBlock(16, 64)

    conv = block.conv
    # NOTE(review): 48 == 64 - 16; presumably the remaining 16 channels come
    # from the pooled input -- confirm against DownsamplerBlock.
    assert (conv.in_channels, conv.out_channels) == (16, 48)

    assert len(block.bn.weight) == 64

    pool = block.pool
    assert (pool.kernel_size, pool.stride) == (2, 2)
|
|
|
|
|
|
def test_erfnet_non_bottleneck_1d():
    """Check selected entries of ``NonBottleneck1d(16, 0, 1).convs_layers``.

    Entries 0, 2, 5 and 7 must map 16 channels to 16 channels, and entry 9
    must expose ``p == 0`` (presumably the dropout probability passed as
    the second constructor argument -- confirm against NonBottleneck1d).
    """
    layers = NonBottleneck1d(16, 0, 1).convs_layers

    for idx in (0, 2, 5, 7):
        assert layers[idx].in_channels == 16
        assert layers[idx].out_channels == 16

    assert layers[9].p == 0
|
|
|
|
|
|
def test_erfnet_upsampler_block():
    """Check the sub-layer attributes of ``UpsamplerBlock(64, 16)``."""
    block = UpsamplerBlock(64, 16)

    conv = block.conv
    assert (conv.in_channels, conv.out_channels) == (64, 16)

    # The normalization layer's weight has one entry per output channel.
    assert len(block.bn.weight) == 16
|