mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add unit test for sep_fcn_head: debug 0
This commit is contained in:
parent
66224e96c5
commit
36228f54dc
@ -8,7 +8,7 @@ class DepthwiseSeparableFCNHead(FCNHead):
|
||||
"""Depthwise-Separable Fully Convolutional Network for Semantic
|
||||
Segmentation.
|
||||
|
||||
This head is implemented according to Fast-SCNN.
|
||||
This head is implemented according to Fast-SCNN paper.
|
||||
Args:
|
||||
in_channels(int): Number of output channels of FFM.
|
||||
channels(int): Number of middle-stage channels in the decode head.
|
||||
|
@ -6,7 +6,8 @@ from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
||||
DepthwiseSeparableASPPHead, EncHead,
|
||||
DepthwiseSeparableASPPHead,
|
||||
DepthwiseSeparableFCNHead, EncHead,
|
||||
FCNHead, GCHead, NLHead, OCRHead,
|
||||
PSAHead, PSPHead, UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
@ -539,3 +540,21 @@ def test_dw_aspp_head():
|
||||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_sep_fcn_head():
|
||||
# test sep_fcn_head with concat_input=False
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=128,
|
||||
channels=128,
|
||||
concat_input=False,
|
||||
num_classes=19,
|
||||
in_index=-1)
|
||||
x = torch.rand(1, 128, 32, 32)
|
||||
output = head(x)
|
||||
assert output.shape == (1, head.num_classes, 32, 32)
|
||||
assert not head.concat_input
|
||||
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
assert head.conv_seg.kernel_size == (1, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user