# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import CGNet
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
                                          GlobalContextExtractor)


def test_cgnet_GlobalContextExtractor():
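    """Test GlobalContextExtractor with checkpointing: output shape must
    match the input."""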
    block = GlobalContextExtractor(16, 16, with_cp=True)
    x = torch.randn(2, 16, 64, 64, requires_grad=True)
    x_out = block(x)
    # GlobalContextExtractor only rescales channels, so the shape is kept
    assert x_out.shape == torch.Size([2, 16, 64, 64])


def test_cgnet_context_guided_block():
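    """Test ContextGuidedBlock: constructor checks, checkpointing, and the
    downsampling / non-downsampling variants."""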
    with pytest.raises(AssertionError):
        # ContextGuidedBlock(8, 8) violates the channel/reduction constraint
        # of its internal GlobalContextExtractor (channel must be >= the
        # reduction, which defaults to 16).
        ContextGuidedBlock(8, 8)

    # test cgnet ContextGuidedBlock with checkpoint forward
    block = ContextGuidedBlock(
        16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
    assert block.with_cp
    # requires_grad=True so the checkpointed forward is exercised by autograd
    x = torch.randn(2, 16, 64, 64, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 16, 64, 64])

    # test cgnet ContextGuidedBlock without checkpoint forward
    block = ContextGuidedBlock(32, 32)
    assert not block.with_cp
    x = torch.randn(3, 32, 32, 32)
    x_out = block(x)
    assert x_out.shape == torch.Size([3, 32, 32, 32])

    # test cgnet ContextGuidedBlock with down sampling
    block = ContextGuidedBlock(32, 32, downsample=True)
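    # despite its name, conv1x1 is built as a 3x3 stride-2 conv when
    # downsample=True, which the assertions below verify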
    assert block.conv1x1.conv.in_channels == 32
    assert block.conv1x1.conv.out_channels == 32
    assert block.conv1x1.conv.kernel_size == (3, 3)
    assert block.conv1x1.conv.stride == (2, 2)
    assert block.conv1x1.conv.padding == (1, 1)

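    # f_loc (local feature) and f_sur (surrounding context) are depthwise
    # 3x3 convs (groups == channels); f_sur uses dilation 2 for a larger
    # receptive field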
    assert block.f_loc.in_channels == 32
    assert block.f_loc.out_channels == 32
    assert block.f_loc.kernel_size == (3, 3)
    assert block.f_loc.stride == (1, 1)
    assert block.f_loc.padding == (1, 1)
    assert block.f_loc.groups == 32
    assert block.f_loc.dilation == (1, 1)
    assert block.f_loc.bias is None

    assert block.f_sur.in_channels == 32
    assert block.f_sur.out_channels == 32
    assert block.f_sur.kernel_size == (3, 3)
    assert block.f_sur.stride == (1, 1)
    assert block.f_sur.padding == (2, 2)
    assert block.f_sur.groups == 32
    assert block.f_sur.dilation == (2, 2)
    assert block.f_sur.bias is None

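    # the 1x1 bottleneck fuses the concatenated local and surrounding
    # features (64 channels) back down to out_channels (32)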
    assert block.bottleneck.in_channels == 64
    assert block.bottleneck.out_channels == 32
    assert block.bottleneck.kernel_size == (1, 1)
    assert block.bottleneck.stride == (1, 1)
    assert block.bottleneck.bias is None

    x = torch.randn(1, 32, 32, 32)
    x_out = block(x)
    # the stride-2 conv halves the spatial resolution
    assert x_out.shape == torch.Size([1, 32, 16, 16])

    # test cgnet ContextGuidedBlock without down sampling
    block = ContextGuidedBlock(32, 32, downsample=False)
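    # without downsampling, conv1x1 is a true 1x1 conv that halves the
    # channels, so f_loc and f_sur operate on 16 channels each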
    assert block.conv1x1.conv.in_channels == 32
    assert block.conv1x1.conv.out_channels == 16
    assert block.conv1x1.conv.kernel_size == (1, 1)
    assert block.conv1x1.conv.stride == (1, 1)
    assert block.conv1x1.conv.padding == (0, 0)

    assert block.f_loc.in_channels == 16
    assert block.f_loc.out_channels == 16
    assert block.f_loc.kernel_size == (3, 3)
    assert block.f_loc.stride == (1, 1)
    assert block.f_loc.padding == (1, 1)
    assert block.f_loc.groups == 16
    assert block.f_loc.dilation == (1, 1)
    assert block.f_loc.bias is None

    assert block.f_sur.in_channels == 16
    assert block.f_sur.out_channels == 16
    assert block.f_sur.kernel_size == (3, 3)
    assert block.f_sur.stride == (1, 1)
    assert block.f_sur.padding == (2, 2)
    assert block.f_sur.groups == 16
    assert block.f_sur.dilation == (2, 2)
    assert block.f_sur.bias is None

    x = torch.randn(1, 32, 32, 32)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 32, 32])


def test_cgnet_backbone():
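    """Test the CGNet backbone: invalid configurations must raise, and a
    forward pass must yield three feature maps of the expected shapes."""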
    with pytest.raises(AssertionError):
        # check invalid num_channels (must be a tuple of length 3)
        CGNet(num_channels=(32, 64, 128, 256))

    with pytest.raises(AssertionError):
        # check invalid num_blocks (must be a tuple of length 2)
        CGNet(num_blocks=(3, 21, 3))

    with pytest.raises(AssertionError):
        # check invalid num_blocks (an int instead of a tuple)
        CGNet(num_blocks=2)

    with pytest.raises(AssertionError):
        # check invalid reductions (an int instead of a tuple)
        CGNet(reductions=16)

    with pytest.raises(AssertionError):
        # check incompatible num_channels and reductions: each reduction
        # must not exceed the corresponding stage channels
        CGNet(num_channels=(32, 64, 128), reductions=(64, 129))

    # Test CGNet with default settings
    model = CGNet()
    model.init_weights()
    model.train()

    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
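    # with the default num_channels=(32, 64, 128), each stage output
    # concatenates feature maps (plus the downsampled image for the first
    # two stages): 32 + 3 = 35, 64 + 64 + 3 = 131, 128 + 128 = 256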
    assert feat[0].shape == torch.Size([2, 35, 112, 112])
    assert feat[1].shape == torch.Size([2, 131, 56, 56])
    assert feat[2].shape == torch.Size([2, 256, 28, 28])

    # Test CGNet with norm_eval True and with_cp True
    model = CGNet(norm_eval=True, with_cp=True)
    with pytest.raises(TypeError):
        # check invalid pretrained: it must be a str or None
        model.init_weights(pretrained=8)
    model.init_weights()
    model.train()

    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size([2, 35, 112, 112])
    assert feat[1].shape == torch.Size([2, 131, 56, 56])
    assert feat[2].shape == torch.Size([2, 256, 28, 28])