add fast_scnn backbone test

pull/58/head
johnzja 2020-08-07 15:46:49 +08:00
parent 2c77085dc2
commit c16940c102
2 changed files with 31 additions and 1 deletions

View File

@ -227,6 +227,13 @@ class FastSCNN(nn.Module):
"""
super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same with Higher Input Channels!')
elif global_out_channels != lower_in_channels:
raise AssertionError('Global Output Channels must be the same with Lower Input Channels!')
if scale_factor != 4:
raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!')
self.in_channels = in_channels
self.downsample_dw_channels1 = downsample_dw_channels1
self.downsample_dw_channels2 = downsample_dw_channels2

View File

@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch.nn.modules import AvgPool2d, GroupNorm
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
from mmseg.models.utils import ResLayer
@ -664,3 +664,26 @@ def test_resnext_backbone():
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
def test_fastscnn_backbone():
with pytest.raises(AssertionError):
# Fast-SCNN channel constraints.
FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128)
# Test FastSCNN Standard Forward
model = FastSCNN()
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 1024, 2048)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([1, 64, 128, 256]) # higher-res
assert feat[1].shape == torch.Size([1, 128, 32, 64]) # lower-res
assert feat[2].shape == torch.Size([1, 128, 128, 256]) # FFM output