add unit test for sep_fcn_head: debug 5
parent
c89674d6cb
commit
ed3a6d0a70
|
@ -27,6 +27,7 @@ class FCNHead(BaseDecodeHead):
|
|||
assert num_convs > 0
|
||||
self.num_convs = num_convs
|
||||
self.concat_input = concat_input
|
||||
self.kernel_size = kernel_size
|
||||
super(FCNHead, self).__init__(**kwargs)
|
||||
convs = []
|
||||
convs.append(
|
||||
|
|
|
@ -45,6 +45,6 @@ class DepthwiseSeparableFCNHead(FCNHead):
|
|||
self.conv_cat = DepthwiseSeparableConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
self.channels,
|
||||
padding=1,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
norm_cfg=self.norm_cfg)
|
||||
|
|
|
@ -559,3 +559,18 @@ def test_sep_fcn_head():
|
|||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
assert head.conv_seg.kernel_size == (1, 1)
|
||||
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=64,
|
||||
channels=64,
|
||||
concat_input=True,
|
||||
num_classes=19,
|
||||
in_index=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||
x = [torch.rand(3, 64, 32, 32)]
|
||||
output = head(x)
|
||||
assert output.shape == (3, head.num_classes, 32, 32)
|
||||
assert head.concat_input
|
||||
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
|
|
Loading…
Reference in New Issue