mirror of https://github.com/open-mmlab/mmcv.git
fix bug of convmodule (#889)
* fix bug of convmodule
* fix bug of convmodule
* fix unit test
* remove assert

pull/902/head
parent 97730c2316
commit 371a21728f
@@ -146,6 +146,8 @@ class ConvModule(nn.Module):
                 norm_channels = in_channels
             self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
             self.add_module(self.norm_name, norm)
+        else:
+            self.norm_name = None
 
         # build activation layer
         if self.with_activation:
@@ -162,7 +164,10 @@ class ConvModule(nn.Module):
 
     @property
     def norm(self):
-        return getattr(self, self.norm_name)
+        if self.norm_name:
+            return getattr(self, self.norm_name)
+        else:
+            return None
 
     def init_weights(self):
         # 1. It is mainly for customized conv layers with their own
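In short, a ConvModule built without a norm_cfg now sets norm_name to None, and the norm property returns None instead of failing on a missing attribute. A minimal sketch of the patched behaviour, assuming mmcv with this fix is installed and ConvModule is importable from mmcv.cnn (the same setup the tests below rely on):

import torch
from mmcv.cnn import ConvModule

# No norm_cfg, so no normalization layer is built.
conv = ConvModule(3, 8, 2)
assert not conv.with_norm
# Before this patch self.norm_name was never set in this branch, so accessing
# conv.norm raised AttributeError; with the fix the property returns None.
assert conv.norm is None
out = conv(torch.rand(1, 3, 256, 256))
assert out.shape == (1, 8, 255, 255)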
@@ -75,7 +75,7 @@ def test_conv_module():
     assert conv.with_activation
     assert hasattr(conv, 'activate')
     assert not conv.with_norm
-    assert not hasattr(conv, 'norm')
+    assert conv.norm is None
     x = torch.rand(1, 3, 256, 256)
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
@@ -83,7 +83,7 @@ def test_conv_module():
     # conv
     conv = ConvModule(3, 8, 2, act_cfg=None)
     assert not conv.with_norm
-    assert not hasattr(conv, 'norm')
+    assert conv.norm is None
     assert not conv.with_activation
     assert not hasattr(conv, 'activate')
     x = torch.rand(1, 3, 256, 256)
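The test updates follow from the same change: because norm_name is now always defined, hasattr(conv, 'norm') is True even when no normalization layer exists, so checking conv.norm is None is the reliable way to assert its absence. A small illustration under the same assumptions as the sketch above:

from mmcv.cnn import ConvModule

conv = ConvModule(3, 8, 2, act_cfg=None)
# The norm property now resolves (to None) even without a norm layer, so the
# old `assert not hasattr(conv, 'norm')` check would no longer pass here.
assert hasattr(conv, 'norm')
assert conv.norm is None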