add unit test for sep_fcn_head: debug 0

This commit is contained in:
johnzja 2020-08-12 10:07:49 +08:00
parent 66224e96c5
commit 36228f54dc
2 changed files with 21 additions and 2 deletions

View File

@ -8,7 +8,7 @@ class DepthwiseSeparableFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation.
This head is implemented according to Fast-SCNN.
This head is implemented according to Fast-SCNN paper.
Args:
in_channels(int): Number of output channels of FFM.
channels(int): Number of middle-stage channels in the decode head.

View File

@ -6,7 +6,8 @@ from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import SyncBatchNorm
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
DepthwiseSeparableASPPHead, EncHead,
DepthwiseSeparableASPPHead,
DepthwiseSeparableFCNHead, EncHead,
FCNHead, GCHead, NLHead, OCRHead,
PSAHead, PSPHead, UPerHead)
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
@ -539,3 +540,21 @@ def test_dw_aspp_head():
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45)
def test_sep_fcn_head():
# test sep_fcn_head with concat_input=False
head = DepthwiseSeparableFCNHead(
in_channels=128,
channels=128,
concat_input=False,
num_classes=19,
in_index=-1)
x = torch.rand(1, 128, 32, 32)
output = head(x)
assert output.shape == (1, head.num_classes, 32, 32)
assert not head.concat_input
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
assert head.conv_seg.kernel_size == (1, 1)