# Copyright (c) OpenMMLab. All rights reserved. import pytest import torch from mmseg.models.backbones import STDCContextPathNet from mmseg.models.backbones.stdc import (AttentionRefinementModule, FeatureFusionModule, STDCModule, STDCNet) def test_stdc_context_path_net(): # Test STDCContextPathNet Standard Forward model = STDCContextPathNet( backbone_cfg=dict( type='STDCNet', stdc_type='STDCNet1', in_channels=3, channels=(32, 64, 256, 512, 1024), bottleneck_type='cat', num_convs=4, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'), with_final_conv=True), last_in_channels=(1024, 512), out_channels=128, ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)) model.init_weights() model.train() batch_size = 2 imgs = torch.randn(batch_size, 3, 256, 512) feat = model(imgs) assert len(feat) == 4 # output for segment Head assert feat[0].shape == torch.Size([batch_size, 256, 32, 64]) # for auxiliary head 1 assert feat[1].shape == torch.Size([batch_size, 128, 16, 32]) # for auxiliary head 2 assert feat[2].shape == torch.Size([batch_size, 128, 32, 64]) # for auxiliary head 3 assert feat[3].shape == torch.Size([batch_size, 256, 32, 64]) # Test input with rare shape batch_size = 2 imgs = torch.randn(batch_size, 3, 527, 279) model = STDCContextPathNet( backbone_cfg=dict( type='STDCNet', stdc_type='STDCNet1', in_channels=3, channels=(32, 64, 256, 512, 1024), bottleneck_type='add', num_convs=4, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'), with_final_conv=False), last_in_channels=(1024, 512), out_channels=128, ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)) model.init_weights() model.train() feat = model(imgs) assert len(feat) == 4 def test_stdcnet(): with pytest.raises(AssertionError): # STDC backbone constraints. STDCNet( stdc_type='STDCNet3', in_channels=3, channels=(32, 64, 256, 512, 1024), bottleneck_type='cat', num_convs=4, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'), with_final_conv=False) with pytest.raises(AssertionError): # STDC bottleneck type constraints. STDCNet( stdc_type='STDCNet1', in_channels=3, channels=(32, 64, 256, 512, 1024), bottleneck_type='dog', num_convs=4, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'), with_final_conv=False) with pytest.raises(AssertionError): # STDC channels length constraints. STDCNet( stdc_type='STDCNet1', in_channels=3, channels=(16, 32, 64, 256, 512, 1024), bottleneck_type='cat', num_convs=4, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='ReLU'), with_final_conv=False) def test_feature_fusion_module(): x_ffm = FeatureFusionModule(in_channels=64, out_channels=32) assert x_ffm.conv0.in_channels == 64 assert x_ffm.attention[1].in_channels == 32 assert x_ffm.attention[2].in_channels == 8 assert x_ffm.attention[2].out_channels == 32 x1 = torch.randn(2, 32, 32, 64) x2 = torch.randn(2, 32, 32, 64) x_out = x_ffm(x1, x2) assert x_out.shape == torch.Size([2, 32, 32, 64]) def test_attention_refinement_module(): x_arm = AttentionRefinementModule(128, 32) assert x_arm.conv_layer.in_channels == 128 assert x_arm.atten_conv_layer[1].conv.out_channels == 32 x = torch.randn(2, 128, 32, 64) x_out = x_arm(x) assert x_out.shape == torch.Size([2, 32, 32, 64]) def test_stdc_module(): x_stdc = STDCModule(in_channels=32, out_channels=32, stride=4) assert x_stdc.layers[0].conv.in_channels == 32 assert x_stdc.layers[3].conv.out_channels == 4 x = torch.randn(2, 32, 32, 64) x_out = x_stdc(x) assert x_out.shape == torch.Size([2, 32, 32, 64])