1
0
mirror of https://github.com/open-mmlab/mmsegmentation.git synced 2025-06-03 22:03:48 +08:00

add fast_scnn backbone test

This commit is contained in:
johnzja 2020-08-07 15:46:49 +08:00
parent 2c77085dc2
commit c16940c102
2 changed files with 31 additions and 1 deletions
mmseg/models/backbones
tests/test_models

@ -227,6 +227,13 @@ class FastSCNN(nn.Module):
""" """
super(FastSCNN, self).__init__() super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same with Higher Input Channels!')
elif global_out_channels != lower_in_channels:
raise AssertionError('Global Output Channels must be the same with Lower Input Channels!')
if scale_factor != 4:
raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!')
self.in_channels = in_channels self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels1 self.downsample_dw_channels1 = downsample_dw_channels1
self.downsample_dw_channels2 = downsample_dw_channels2 self.downsample_dw_channels2 = downsample_dw_channels2

@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from torch.nn.modules import AvgPool2d, GroupNorm from torch.nn.modules import AvgPool2d, GroupNorm
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
from mmseg.models.utils import ResLayer from mmseg.models.utils import ResLayer
@ -664,3 +664,26 @@ def test_resnext_backbone():
assert feat[1].shape == torch.Size([1, 512, 28, 28]) assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14]) assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7]) assert feat[3].shape == torch.Size([1, 2048, 7, 7])
def test_fastscnn_backbone():
with pytest.raises(AssertionError):
# Fast-SCNN channel constraints.
FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128)
# Test FastSCNN Standard Forward
model = FastSCNN()
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 1024, 2048)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([1, 64, 128, 256]) # higher-res
assert feat[1].shape == torch.Size([1, 128, 32, 64]) # lower-res
assert feat[2].shape == torch.Size([1, 128, 128, 256]) # FFM output