# mmsegmentation/tests/test_models/test_backbones/test_bisenetv1.py
#
# 110 lines
# 3.3 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import BiSeNetV1
from mmseg.models.backbones.bisenetv1 import (AttentionRefinementModule,
ContextPath, FeatureFusionModule,
SpatialPath)
def test_bisenetv1_backbone():
    """Check BiSeNetV1 forward output shapes and constructor argument checks.

    Builds BiSeNetV1 with a ResNet-18 ``backbone_cfg``, runs a forward pass on
    a regular (64x128) and an irregular (95x27) input and checks every output
    shape, then verifies that malformed ``spatial_channels`` /
    ``context_channels`` tuples raise ``AssertionError``.
    """
    # Test BiSeNetV1 Standard Forward
    backbone_cfg = dict(
        type='ResNet',
        in_channels=3,
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 1, 1),
        strides=(1, 2, 2, 2),
        norm_eval=False,
        style='pytorch',
        contract_dilation=True)
    model = BiSeNetV1(in_channels=3, backbone_cfg=backbone_cfg)
    model.init_weights()
    model.train()
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 64, 128)
    feat = model(imgs)

    assert len(feat) == 3
    # output for segment Head
    assert feat[0].shape == torch.Size([batch_size, 256, 8, 16])
    # for auxiliary head 1
    assert feat[1].shape == torch.Size([batch_size, 128, 8, 16])
    # for auxiliary head 2
    assert feat[2].shape == torch.Size([batch_size, 128, 4, 8])

    # Test input with rare shape (odd, non-power-of-two sizes).
    # Checking only the tuple length would not catch a resolution mismatch,
    # so the shapes are pinned too. The 64x128 case above yields 1/8 and
    # 1/16 maps; assuming every stride-2 stage pads by kernel_size // 2 and
    # therefore rounds up (ceil(H / 2)), the 1/8 map is
    # ceil(95/8) x ceil(27/8) = 12x4 and the 1/16 map is
    # ceil(95/16) x ceil(27/16) = 6x2.
    # NOTE(review): re-derive these if the backbone config changes.
    imgs = torch.randn(batch_size, 3, 95, 27)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size([batch_size, 256, 12, 4])
    assert feat[1].shape == torch.Size([batch_size, 128, 12, 4])
    assert feat[2].shape == torch.Size([batch_size, 128, 6, 2])

    with pytest.raises(AssertionError):
        # BiSeNetV1 spatial path channel constraints.
        BiSeNetV1(
            backbone_cfg=backbone_cfg,
            in_channels=3,
            spatial_channels=(16, 16, 16))

    with pytest.raises(AssertionError):
        # BiSeNetV1 context path constraints.
        BiSeNetV1(
            backbone_cfg=backbone_cfg,
            in_channels=3,
            context_channels=(16, 32, 64, 128))
def test_bisenetv1_spatial_path():
    """SpatialPath must reject a ``num_channels`` tuple of the wrong length."""
    # Only three channel entries are supplied; construction must fail.
    bad_channels = (16, 16, 16)
    with pytest.raises(AssertionError):
        SpatialPath(in_channels=3, num_channels=bad_channels)
def test_bisenetv1_context_path():
    """ContextPath must reject a ``context_channels`` tuple of wrong length."""
    # ResNet-50 backbone description handed to ContextPath.
    resnet_cfg = {
        'type': 'ResNet',
        'in_channels': 3,
        'depth': 50,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'dilations': (1, 1, 1, 1),
        'strides': (1, 2, 2, 2),
        'norm_eval': False,
        'style': 'pytorch',
        'contract_dilation': True,
    }
    # Four channel entries are passed; construction must fail.
    too_many_channels = (16, 32, 64, 128)
    with pytest.raises(AssertionError):
        ContextPath(
            backbone_cfg=resnet_cfg, context_channels=too_many_channels)
def test_bisenetv1_attention_refinement_module():
    """AttentionRefinementModule(32, 8): 3x3 conv layer and forward shape."""
    arm = AttentionRefinementModule(32, 8)
    conv = arm.conv_layer
    assert (conv.in_channels, conv.out_channels) == (32, 8)
    assert conv.kernel_size == (3, 3)

    # Spatial size is kept; only the channel count changes (32 -> 8).
    inputs = torch.randn(2, 32, 8, 16)
    outputs = arm(inputs)
    assert outputs.shape == torch.Size([2, 8, 8, 16])
def test_bisenetv1_feature_fusion_module():
    """FeatureFusionModule: layer configuration and forward output shape."""
    # Layer configuration of a 16 -> 32 channel module.
    wide_ffm = FeatureFusionModule(16, 32)
    fuse_conv = wide_ffm.conv1
    assert fuse_conv.in_channels == 16
    assert fuse_conv.out_channels == 32
    assert fuse_conv.kernel_size == (1, 1)
    assert wide_ffm.gap.output_size == (1, 1)
    atten_conv = wide_ffm.conv_atten[0]
    assert (atten_conv.in_channels, atten_conv.out_channels) == (32, 32)
    assert atten_conv.kernel_size == (1, 1)

    # Forward pass: two 8-channel inputs into a module built with
    # in_channels=16 (presumably concatenated along dim 1 -- TODO confirm).
    ffm = FeatureFusionModule(16, 16)
    feat_a = torch.randn(2, 8, 8, 16)
    feat_b = torch.randn(2, 8, 8, 16)
    fused = ffm(feat_a, feat_b)
    assert fused.shape == torch.Size([2, 16, 8, 16])