From ed3a6d0a70db05774cd68c249ab8af073d2be24d Mon Sep 17 00:00:00 2001 From: johnzja Date: Wed, 12 Aug 2020 10:49:35 +0800 Subject: [PATCH] add unit test for sep_fcn_head: debug 5 --- mmseg/models/decode_heads/fcn_head.py | 1 + mmseg/models/decode_heads/sep_fcn_head.py | 4 ++-- tests/test_models/test_heads.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py index e586a2e0d..ff48b5197 100644 --- a/mmseg/models/decode_heads/fcn_head.py +++ b/mmseg/models/decode_heads/fcn_head.py @@ -27,6 +27,7 @@ class FCNHead(BaseDecodeHead): assert num_convs > 0 self.num_convs = num_convs self.concat_input = concat_input + self.kernel_size = kernel_size super(FCNHead, self).__init__(**kwargs) convs = [] convs.append( diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index 34401920a..b7a9bae2f 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -45,6 +45,6 @@ class DepthwiseSeparableFCNHead(FCNHead): self.conv_cat = DepthwiseSeparableConvModule( self.in_channels + self.channels, self.channels, - self.channels, - padding=1, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, norm_cfg=self.norm_cfg) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index dc41e410b..8feb0e64b 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -559,3 +559,18 @@ def test_sep_fcn_head(): assert isinstance(head.convs[0], DepthwiseSeparableConvModule) assert isinstance(head.convs[1], DepthwiseSeparableConvModule) assert head.conv_seg.kernel_size == (1, 1) + + head = DepthwiseSeparableFCNHead( + in_channels=64, + channels=64, + concat_input=True, + num_classes=19, + in_index=-1, + norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) + x = [torch.rand(3, 64, 32, 32)] + output = head(x) + assert output.shape == (3, head.num_classes, 32, 32) + assert head.concat_input + from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule + assert isinstance(head.convs[0], DepthwiseSeparableConvModule) + assert isinstance(head.convs[1], DepthwiseSeparableConvModule)