# mmsegmentation/tests/test_models/test_backbones/test_cgnet.py
#
# 152 lines
# 5.1 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import CGNet
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
GlobalContextExtractor)
def test_cgnet_GlobalContextExtractor():
    """GlobalContextExtractor(16, 16) with checkpointing preserves the shape."""
    shape = (2, 16, 64, 64)
    extractor = GlobalContextExtractor(16, 16, with_cp=True)
    feats = torch.randn(*shape, requires_grad=True)
    assert extractor(feats).shape == torch.Size(shape)
def _assert_conv_attrs(conv, **expected):
    """Assert that attributes of a convolution layer equal expected values.

    Args:
        conv: Convolution module to inspect; each keyword below names one of
            its attributes (e.g. ``in_channels``, ``stride``).
        **expected: Attribute name -> value it must compare ``==`` to.
    """
    for attr_name, attr_value in expected.items():
        assert getattr(conv, attr_name) == attr_value


def test_cgnet_context_guided_block():
    """Check ContextGuidedBlock validation, layer configs and output shapes."""
    with pytest.raises(AssertionError):
        # 8 in/out channels violates the channel and reduction constraints
        # of the GlobalContextExtractor used inside the block.
        ContextGuidedBlock(8, 8)

    # Forward pass with gradient checkpointing enabled.
    cp_block = ContextGuidedBlock(
        16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
    assert cp_block.with_cp
    cp_in = torch.randn(2, 16, 64, 64, requires_grad=True)
    assert cp_block(cp_in).shape == torch.Size([2, 16, 64, 64])

    # Forward pass without checkpointing (constructor default).
    plain_block = ContextGuidedBlock(32, 32)
    assert not plain_block.with_cp
    plain_in = torch.randn(3, 32, 32, 32)
    assert plain_block(plain_in).shape == torch.Size([3, 32, 32, 32])

    # Down-sampling variant: a stride-2 3x3 entry conv, depthwise local
    # (f_loc) and dilated depthwise surrounding (f_sur) convs, and a bias-free
    # 1x1 bottleneck mapping 64 channels back to 32.
    down = ContextGuidedBlock(32, 32, downsample=True)
    _assert_conv_attrs(
        down.conv1x1.conv,
        in_channels=32, out_channels=32, kernel_size=(3, 3),
        stride=(2, 2), padding=(1, 1))
    _assert_conv_attrs(
        down.f_loc,
        in_channels=32, out_channels=32, kernel_size=(3, 3),
        stride=(1, 1), padding=(1, 1), groups=32, dilation=(1, 1))
    assert down.f_loc.bias is None
    _assert_conv_attrs(
        down.f_sur,
        in_channels=32, out_channels=32, kernel_size=(3, 3),
        stride=(1, 1), padding=(2, 2), groups=32, dilation=(2, 2))
    assert down.f_sur.bias is None
    _assert_conv_attrs(
        down.bottleneck,
        in_channels=64, out_channels=32, kernel_size=(1, 1), stride=(1, 1))
    assert down.bottleneck.bias is None
    down_in = torch.randn(1, 32, 32, 32)
    assert down(down_in).shape == torch.Size([1, 32, 16, 16])

    # Non-down-sampling variant: a 1x1 entry conv reduces 32 -> 16 channels
    # and the local/surrounding depthwise convs work on those 16 channels.
    same = ContextGuidedBlock(32, 32, downsample=False)
    _assert_conv_attrs(
        same.conv1x1.conv,
        in_channels=32, out_channels=16, kernel_size=(1, 1),
        stride=(1, 1), padding=(0, 0))
    _assert_conv_attrs(
        same.f_loc,
        in_channels=16, out_channels=16, kernel_size=(3, 3),
        stride=(1, 1), padding=(1, 1), groups=16, dilation=(1, 1))
    assert same.f_loc.bias is None
    _assert_conv_attrs(
        same.f_sur,
        in_channels=16, out_channels=16, kernel_size=(3, 3),
        stride=(1, 1), padding=(2, 2), groups=16, dilation=(2, 2))
    assert same.f_sur.bias is None
    same_in = torch.randn(1, 32, 32, 32)
    assert same(same_in).shape == torch.Size([1, 32, 32, 32])
def _run_cgnet_and_check_shapes(model):
    """Put ``model`` in train mode, run a random batch and check its outputs.

    Expects the three feature maps produced for a 2x3x224x224 input with the
    default CGNet channel settings used in this test.
    """
    model.train()
    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size([2, 35, 112, 112])
    assert feat[1].shape == torch.Size([2, 131, 56, 56])
    assert feat[2].shape == torch.Size([2, 256, 28, 28])


def test_cgnet_backbone():
    """Check CGNet argument validation, weight init, norm_eval and outputs."""
    from torch.nn.modules.batchnorm import _BatchNorm

    with pytest.raises(AssertionError):
        # too many num_channels entries
        CGNet(num_channels=(32, 64, 128, 256))
    with pytest.raises(AssertionError):
        # too many num_blocks entries
        CGNet(num_blocks=(3, 21, 3))
    with pytest.raises(AssertionError):
        # num_blocks given as a scalar instead of a tuple
        CGNet(num_blocks=2)
    with pytest.raises(AssertionError):
        # check invalid dilation: dilations given as a scalar. (The previous
        # version passed num_blocks=2 here, so dilations was never tested.)
        CGNet(dilations=2)
    with pytest.raises(AssertionError):
        # check invalid reduction: reductions given as a scalar
        CGNet(reductions=16)
    with pytest.raises(AssertionError):
        # check invalid num_channels and reduction combination
        CGNet(num_channels=(32, 64, 128), reductions=(64, 129))

    # Test CGNet with default settings
    model = CGNet()
    model.init_weights()
    _run_cgnet_and_check_shapes(model)

    # Test CGNet with norm_eval True and with_cp True
    model = CGNet(norm_eval=True, with_cp=True)
    with pytest.raises(TypeError):
        # check invalid pretrained
        model.init_weights(pretrained=8)
    model.init_weights()
    _run_cgnet_and_check_shapes(model)
    # norm_eval=True must keep normalization layers frozen even after
    # model.train(). Assumes the default norm_cfg builds BatchNorm layers,
    # hence the non-empty check so this assertion cannot pass vacuously.
    bn_layers = [m for m in model.modules() if isinstance(m, _BatchNorm)]
    assert bn_layers
    assert all(not bn.training for bn in bn_layers)