modify format of self.layers

pull/2/head
lixiaojie 2020-06-15 20:42:04 +08:00
parent bb99ca5c66
commit 703714b78e
2 changed files with 15 additions and 21 deletions

View File

@ -92,16 +92,14 @@ class InvertedResidual(nn.Module):
branch_features = planes // 2 branch_features = planes // 2
if self.stride == 1: if self.stride == 1:
assert inplanes == branch_features * 2, (f'inplanes ({inplanes}) ' assert inplanes == branch_features * 2, (
'should equal to ' f'inplanes ({inplanes}) should equal to branch_features * 2 '
'branch_features * 2 ' f'({branch_features * 2}) when stride is 1')
f'({branch_features * 2})'
' when stride is 1')
if inplanes != branch_features * 2: if inplanes != branch_features * 2:
assert self.stride != 1, (f'stride ({self.stride}) should not ' assert self.stride != 1, (
'equal 1 when inplanes != ' f'stride ({self.stride}) should not equal 1 when '
'branch_features * 2') f'inplanes != branch_features * 2')
if self.stride > 1: if self.stride > 1:
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
@ -250,12 +248,10 @@ class ShuffleNetv2(BaseBackbone):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layers = [] self.layers = nn.ModuleList()
for i, num_blocks in enumerate(self.stage_blocks): for i, num_blocks in enumerate(self.stage_blocks):
layer = self._make_layer(channels[i], num_blocks) layer = self._make_layer(channels[i], num_blocks)
layer_name = f'layer{i + 1}' self.layers.append(layer)
self.add_module(layer_name, layer)
self.layers.append(layer_name)
output_channels = channels[-1] output_channels = channels[-1]
self.conv2 = ConvModule( self.conv2 = ConvModule(
@ -294,8 +290,8 @@ class ShuffleNetv2(BaseBackbone):
for param in self.conv1.parameters(): for param in self.conv1.parameters():
param.requires_grad = False param.requires_grad = False
for i in range(1, self.frozen_stages + 1): for i in range(self.frozen_stages):
m = getattr(self, f'layer{i}') m = self.layers[i]
m.eval() m.eval()
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.requires_grad = False
@ -316,8 +312,7 @@ class ShuffleNetv2(BaseBackbone):
x = self.maxpool(x) x = self.maxpool(x)
outs = [] outs = []
for i, layer_name in enumerate(self.layers): for i, layer in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x) x = layer(x)
if i in self.out_indices: if i in self.out_indices:
outs.append(x) outs.append(x)

View File

@ -98,11 +98,10 @@ def test_shufflenetv2_backbone():
model = ShuffleNetv2(frozen_stages=frozen_stages) model = ShuffleNetv2(frozen_stages=frozen_stages)
model.init_weights() model.init_weights()
model.train() model.train()
for layer in [model.conv1]: for param in model.conv1.parameters():
for param in layer.parameters(): assert param.requires_grad is False
assert param.requires_grad is False for i in range(0, frozen_stages):
for i in range(1, frozen_stages + 1): layer = model.layers[i]
layer = getattr(model, f'layer{i}')
for mod in layer.modules(): for mod in layer.modules():
if isinstance(mod, _BatchNorm): if isinstance(mod, _BatchNorm):
assert mod.training is False assert mod.training is False