From 36228f54dc5fc812c35b38838bac14fd90c74ef3 Mon Sep 17 00:00:00 2001 From: johnzja Date: Wed, 12 Aug 2020 10:07:49 +0800 Subject: [PATCH] add unit test for sep_fcn_head: debug 0 --- mmseg/models/decode_heads/sep_fcn_head.py | 2 +- tests/test_models/test_heads.py | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index 538020d1a..34401920a 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -8,7 +8,7 @@ class DepthwiseSeparableFCNHead(FCNHead): """Depthwise-Separable Fully Convolutional Network for Semantic Segmentation. - This head is implemented according to Fast-SCNN. + This head is implemented according to Fast-SCNN paper. Args: in_channels(int): Number of output channels of FFM. channels(int): Number of middle-stage channels in the decode head. diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 3ac6bb0aa..c3a6ed647 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,7 +6,8 @@ from mmcv.cnn import ConvModule from mmcv.utils.parrots_wrapper import SyncBatchNorm from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, - DepthwiseSeparableASPPHead, EncHead, + DepthwiseSeparableASPPHead, + DepthwiseSeparableFCNHead, EncHead, FCNHead, GCHead, NLHead, OCRHead, PSAHead, PSPHead, UPerHead) from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -539,3 +540,21 @@ def test_dw_aspp_head(): assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 45, 45) + + +def test_sep_fcn_head(): + # test sep_fcn_head with concat_input=False + head = DepthwiseSeparableFCNHead( + in_channels=128, + channels=128, + concat_input=False, + num_classes=19, + in_index=-1) + x = torch.rand(1, 128, 32, 32) + output = head(x) + assert output.shape == (1, head.num_classes, 32, 32) + assert not head.concat_input + from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule + assert isinstance(head.convs[0], DepthwiseSeparableConvModule) + assert isinstance(head.convs[1], DepthwiseSeparableConvModule) + assert head.conv_seg.kernel_size == (1, 1)