fix bug of convmodule (#889)

* fix bug of convmodule

* fix bug of convmodule

* fix unit test

* remove assert
pull/902/head
ZhangShilong 2021-03-20 23:08:20 +08:00 committed by GitHub
parent 97730c2316
commit 371a21728f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -146,6 +146,8 @@ class ConvModule(nn.Module):
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.add_module(self.norm_name, norm)
else:
self.norm_name = None
# build activation layer
if self.with_activation:
@ -162,7 +164,10 @@ class ConvModule(nn.Module):
@property
def norm(self):
    """Return the normalization layer of this module, or ``None``.

    ``self.norm_name`` is set in ``__init__``: it holds the registered
    submodule name when a norm layer was built, and ``None`` when the
    module was created without normalization.

    Returns:
        nn.Module | None: The norm layer registered via ``add_module``,
        or ``None`` if no norm layer exists.
    """
    # Guard the no-norm case explicitly: ``getattr(self, None)`` would
    # raise ``TypeError: attribute name must be string``, so return None
    # instead of looking up a nonexistent attribute.
    if self.norm_name:
        return getattr(self, self.norm_name)
    else:
        return None
def init_weights(self):
# 1. It is mainly for customized conv layers with their own

View File

@ -75,7 +75,7 @@ def test_conv_module():
assert conv.with_activation
assert hasattr(conv, 'activate')
assert not conv.with_norm
assert not hasattr(conv, 'norm')
assert conv.norm is None
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
@ -83,7 +83,7 @@ def test_conv_module():
# conv
conv = ConvModule(3, 8, 2, act_cfg=None)
assert not conv.with_norm
assert not hasattr(conv, 'norm')
assert conv.norm is None
assert not conv.with_activation
assert not hasattr(conv, 'activate')
x = torch.rand(1, 3, 256, 256)