add fast_scnn backbone test
parent 2c77085dc2
commit c16940c102

@@ -227,6 +227,13 @@ class FastSCNN(nn.Module):
         """

         super(FastSCNN, self).__init__()
+        if global_in_channels != higher_in_channels:
+            raise AssertionError('Global Input Channels must be the same with Higher Input Channels!')
+        elif global_out_channels != lower_in_channels:
+            raise AssertionError('Global Output Channels must be the same with Lower Input Channels!')
+        if scale_factor != 4:
+            raise AssertionError('Scale-factor must compensate the downsampling factor in the GFE module!')
+
         self.in_channels = in_channels
         self.downsample_dw_channels1 = downsample_dw_channels1
         self.downsample_dw_channels2 = downsample_dw_channels2
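The scale-factor check reflects how the two Fast-SCNN branches line up: the learning-to-downsample path emits features at 1/8 of the input resolution, the global feature extractor (GFE) reduces them further to 1/32, so the feature fusion module (FFM) has to upsample the global branch by exactly 4 before fusing it with the higher-resolution branch. A minimal sketch of that bookkeeping, assuming the usual Fast-SCNN strides (the variable names below are illustrative, not taken from the module):

    # Illustrative only: why scale_factor must be 4 under the default strides.
    lds_stride = 8    # learning-to-downsample output: 1/8 of the input (assumed)
    gfe_stride = 32   # global feature extractor output: 1/32 of the input (assumed)
    scale_factor = gfe_stride // lds_stride
    assert scale_factor == 4  # factor by which the FFM must upsample the global branch
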
@@ -4,7 +4,7 @@ from mmcv.ops import DeformConv2dPack
 from mmcv.utils.parrots_wrapper import _BatchNorm
 from torch.nn.modules import AvgPool2d, GroupNorm

-from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt
+from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt, FastSCNN
 from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
 from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
 from mmseg.models.utils import ResLayer

@@ -664,3 +664,26 @@ def test_resnext_backbone():
     assert feat[1].shape == torch.Size([1, 512, 28, 28])
     assert feat[2].shape == torch.Size([1, 1024, 14, 14])
     assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+
+def test_fastscnn_backbone():
+    with pytest.raises(AssertionError):
+        # Fast-SCNN channel constraints.
+        FastSCNN(3, 32, 48, 64, (64, 96, 128), 127, 64, 128)
+
+    # Test FastSCNN Standard Forward
+    model = FastSCNN()
+    model.init_weights()
+    model.train()
+    imgs = torch.randn(1, 3, 1024, 2048)
+    feat = model(imgs)
+
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 64, 128, 256])   # higher-res
+    assert feat[1].shape == torch.Size([1, 128, 32, 64])    # lower-res
+    assert feat[2].shape == torch.Size([1, 128, 128, 256])  # FFM output
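Two details of the test are worth spelling out. In the pytest.raises call, the 127 is presumably the global output channel count, which must match the lower input channel count of 128, so construction trips the constraint added in the backbone above. The three asserted shapes then follow from the 1/8 and 1/32 strides for a 1024x2048 input, with the FFM output staying at 1/8 resolution. A small sketch of the shape arithmetic, assuming those default strides (pure Python, nothing here runs the model):

    # Illustrative only: expected spatial sizes for a 1024x2048 input.
    h, w = 1024, 2048
    assert (h // 8, w // 8) == (128, 256)    # feat[0]: higher-resolution branch
    assert (h // 32, w // 32) == (32, 64)    # feat[1]: lower-resolution (global) branch
    assert (h // 8, w // 8) == (128, 256)    # feat[2]: FFM output, fused back at 1/8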