add unit test for sep_fcn_head: debug 5

pull/58/head
johnzja 2020-08-12 10:49:35 +08:00
parent c89674d6cb
commit ed3a6d0a70
3 changed files with 18 additions and 2 deletions

View File

@ -27,6 +27,7 @@ class FCNHead(BaseDecodeHead):
assert num_convs > 0
self.num_convs = num_convs
self.concat_input = concat_input
self.kernel_size = kernel_size
super(FCNHead, self).__init__(**kwargs)
convs = []
convs.append(

View File

@ -45,6 +45,6 @@ class DepthwiseSeparableFCNHead(FCNHead):
self.conv_cat = DepthwiseSeparableConvModule(
self.in_channels + self.channels,
self.channels,
self.channels,
padding=1,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg)

View File

@ -559,3 +559,18 @@ def test_sep_fcn_head():
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
assert head.conv_seg.kernel_size == (1, 1)
head = DepthwiseSeparableFCNHead(
in_channels=64,
channels=64,
concat_input=True,
num_classes=19,
in_index=-1,
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
x = [torch.rand(3, 64, 32, 32)]
output = head(x)
assert output.shape == (3, head.num_classes, 32, 32)
assert head.concat_input
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)