132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.backbones import STDCContextPathNet
|
|
from mmseg.models.backbones.stdc import (AttentionRefinementModule,
|
|
FeatureFusionModule, STDCModule,
|
|
STDCNet)
|
|
|
|
|
|
def _build_stdc_context_path_net(bottleneck_type, with_final_conv):
    """Build the STDCNet1-based ``STDCContextPathNet`` used by the tests.

    Only the two backbone options that the forward tests vary are exposed;
    every other setting is shared by both configurations.

    Args:
        bottleneck_type (str): Forwarded as the backbone's
            ``bottleneck_type`` ('cat' or 'add' in these tests).
        with_final_conv (bool): Forwarded as the backbone's
            ``with_final_conv``.

    Returns:
        STDCContextPathNet: The model, with ``init_weights()`` already
        called and put in training mode.
    """
    model = STDCContextPathNet(
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type=bottleneck_type,
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=with_final_conv),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
    model.init_weights()
    model.train()
    return model


def test_stdc_context_path_net():
    """Check the number and shapes of ``STDCContextPathNet`` outputs."""
    batch_size = 2

    # Test STDCContextPathNet Standard Forward
    model = _build_stdc_context_path_net('cat', with_final_conv=True)
    imgs = torch.randn(batch_size, 3, 256, 512)
    feat = model(imgs)

    assert len(feat) == 4
    # NOTE(review): the head each output feeds (comments below) is carried
    # over from the original test; confirm against
    # STDCContextPathNet.forward.
    # output for segment Head
    assert feat[0].shape == torch.Size([batch_size, 256, 32, 64])
    # for auxiliary head 1
    assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
    # for auxiliary head 2
    assert feat[2].shape == torch.Size([batch_size, 128, 32, 64])
    # for auxiliary head 3
    assert feat[3].shape == torch.Size([batch_size, 256, 32, 64])

    # Test input with rare shape (odd height and width).
    imgs = torch.randn(batch_size, 3, 527, 279)
    model = _build_stdc_context_path_net('add', with_final_conv=False)
    feat = model(imgs)

    assert len(feat) == 4
    # Channel counts come from the config, not the input resolution, so they
    # must match the standard-forward case above.
    expected_channels = (256, 128, 128, 256)
    for out, channels in zip(feat, expected_channels):
        assert out.shape[:2] == (batch_size, channels)
    # In the standard case outputs 0, 2 and 3 share one spatial size; odd
    # input sizes must not break that agreement.
    assert feat[0].shape[2:] == feat[2].shape[2:] == feat[3].shape[2:]
|
|
|
|
|
|
def test_stdcnet():
    """Each invalid ``STDCNet`` argument below must raise AssertionError."""

    def make_kwargs(**override):
        # Build a fresh kwargs dict per call so no config dict is shared
        # between constructions.
        kwargs = dict(
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False)
        kwargs.update(override)
        return kwargs

    invalid_cases = (
        # STDC backbone constraints.
        dict(stdc_type='STDCNet3'),
        # STDC bottleneck type constraints.
        dict(bottleneck_type='dog'),
        # STDC channels length constraints.
        dict(channels=(16, 32, 64, 256, 512, 1024)),
    )
    for override in invalid_cases:
        with pytest.raises(AssertionError):
            STDCNet(**make_kwargs(**override))
|
|
|
|
|
|
def test_feature_fusion_module():
    """Check ``FeatureFusionModule`` layer channels and forward shape."""
    ffm = FeatureFusionModule(in_channels=64, out_channels=32)

    # conv0 takes 64 channels, while each of the two inputs below has 32.
    assert ffm.conv0.in_channels == 64
    attention = ffm.attention
    assert attention[1].in_channels == 32
    assert (attention[2].in_channels, attention[2].out_channels) == (8, 32)

    shape = (2, 32, 32, 64)
    lhs = torch.randn(*shape)
    rhs = torch.randn(*shape)
    fused = ffm(lhs, rhs)
    assert fused.shape == torch.Size(shape)
|
|
|
|
|
|
def test_attention_refinement_module():
    """Check ``AttentionRefinementModule`` channels and forward shape."""
    in_ch, out_ch = 128, 32
    arm = AttentionRefinementModule(in_ch, out_ch)
    assert arm.conv_layer.in_channels == in_ch
    assert arm.atten_conv_layer[1].conv.out_channels == out_ch

    # Only the channel dimension changes; batch and spatial sizes are kept.
    inputs = torch.randn(2, in_ch, 32, 64)
    assert arm(inputs).shape == torch.Size([2, out_ch, 32, 64])
|
|
|
|
|
|
def test_stdc_module():
    """Check ``STDCModule`` layer channels and forward shape (stride=4)."""
    module = STDCModule(in_channels=32, out_channels=32, stride=4)
    assert module.layers[0].conv.in_channels == 32
    assert module.layers[3].conv.out_channels == 4

    # With this configuration the output shape equals the input shape.
    x = torch.randn(2, 32, 32, 64)
    y = module(x)
    assert tuple(y.shape) == (2, 32, 32, 64)
|