# Source file: mmsegmentation/tests/test_models/test_backbones/test_stdc.py
# (132 lines, 4.4 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import STDCContextPathNet
from mmseg.models.backbones.stdc import (AttentionRefinementModule,
FeatureFusionModule, STDCModule,
STDCNet)
def _build_stdc_context_path_net(bottleneck_type, with_final_conv):
    """Build an STDCNet1-based ``STDCContextPathNet`` for the tests below.

    Both scenarios in ``test_stdc_context_path_net`` use identical settings
    except for the backbone's ``bottleneck_type`` and ``with_final_conv``,
    so only those two are parameters here. ``init_weights()`` is called and
    the model is put in train mode before it is returned.

    Args:
        bottleneck_type (str): Forwarded to the backbone config
            (the tests use ``'cat'`` and ``'add'``).
        with_final_conv (bool): Forwarded to the backbone config.

    Returns:
        STDCContextPathNet: The initialized model, in train mode.
    """
    model = STDCContextPathNet(
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type=bottleneck_type,
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=with_final_conv),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
    model.init_weights()
    model.train()
    return model


def test_stdc_context_path_net():
    """Forward ``STDCContextPathNet`` and check its four output feature maps.

    Scenario 1 feeds a 256x512 batch through a ``'cat'``-bottleneck backbone
    with a final conv and pins the exact shape of every output. Scenario 2
    feeds an odd-sized 527x279 batch through an ``'add'``-bottleneck backbone
    without a final conv. It checks the output count, the batch/channel dims
    (which the config fixes independently of the input size), and that
    outputs 0, 2 and 3 share one spatial size, as they do in scenario 1.
    """
    batch_size = 2

    # Test STDCContextPathNet Standard Forward
    model = _build_stdc_context_path_net(
        bottleneck_type='cat', with_final_conv=True)
    imgs = torch.randn(batch_size, 3, 256, 512)
    feat = model(imgs)
    assert len(feat) == 4
    # output for segment Head
    assert feat[0].shape == torch.Size([batch_size, 256, 32, 64])
    # for auxiliary head 1
    assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
    # for auxiliary head 2
    assert feat[2].shape == torch.Size([batch_size, 128, 32, 64])
    # for auxiliary head 3
    assert feat[3].shape == torch.Size([batch_size, 256, 32, 64])

    # Test input with rare (odd) spatial shape. The input is drawn before
    # the model is built, matching the original RNG consumption order.
    imgs = torch.randn(batch_size, 3, 527, 279)
    model = _build_stdc_context_path_net(
        bottleneck_type='add', with_final_conv=False)
    feat = model(imgs)
    assert len(feat) == 4
    # Channel counts per output, as asserted in the standard forward above.
    expected_channels = (256, 128, 128, 256)
    for out, channels in zip(feat, expected_channels):
        assert tuple(out.shape[:2]) == (batch_size, channels)
    # Outputs 0, 2 and 3 share a spatial size in the standard forward too.
    assert feat[0].shape[2:] == feat[2].shape[2:] == feat[3].shape[2:]


def test_stdcnet():
    """``STDCNet`` must reject invalid constructor arguments.

    Each case starts from a valid STDCNet1 configuration, replaces exactly
    one argument with an unsupported value, and expects ``AssertionError``.
    """
    invalid_overrides = (
        # STDC backbone constraints: unknown ``stdc_type``.
        dict(stdc_type='STDCNet3'),
        # STDC bottleneck type constraints: unknown ``bottleneck_type``.
        dict(bottleneck_type='dog'),
        # STDC channels length constraints: six channel entries, not five.
        dict(channels=(16, 32, 64, 256, 512, 1024)),
    )
    for override in invalid_overrides:
        # Rebuild the base kwargs on every iteration so no config dict is
        # shared between constructor calls.
        kwargs = dict(
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False)
        kwargs.update(override)
        with pytest.raises(AssertionError):
            STDCNet(**kwargs)


def test_feature_fusion_module():
    """Check ``FeatureFusionModule`` layer channels and its forward shape."""
    fusion = FeatureFusionModule(in_channels=64, out_channels=32)
    attention = fusion.attention
    assert fusion.conv0.in_channels == 64
    assert attention[1].in_channels == 32
    # attention[2] takes 8 input channels and produces 32.
    assert attention[2].in_channels == 8
    assert attention[2].out_channels == 32

    feature_shape = (2, 32, 32, 64)
    first, second = torch.randn(*feature_shape), torch.randn(*feature_shape)
    fused = fusion(first, second)
    assert fused.shape == torch.Size(feature_shape)


def test_attention_refinement_module():
    """Check ``AttentionRefinementModule`` channels and its forward shape."""
    in_ch, out_ch = 128, 32
    arm = AttentionRefinementModule(in_ch, out_ch)
    assert arm.conv_layer.in_channels == in_ch
    assert arm.atten_conv_layer[1].conv.out_channels == out_ch

    # The spatial size is expected to pass through unchanged.
    spatial = (32, 64)
    refined = arm(torch.randn(2, in_ch, *spatial))
    assert refined.shape == torch.Size([2, out_ch, *spatial])


def test_stdc_module():
    """Check ``STDCModule`` layer channels and that forward keeps the shape."""
    module = STDCModule(in_channels=32, out_channels=32, stride=4)
    layers = module.layers
    assert layers[0].conv.in_channels == 32
    assert layers[3].conv.out_channels == 4

    shape = (2, 32, 32, 64)
    output = module(torch.randn(*shape))
    assert output.shape == torch.Size(shape)