modify self.layers

This commit is contained in:
lixiaojie 2020-06-15 20:52:17 +08:00
parent 3bf971238a
commit edceab13e3
3 changed files with 13 additions and 17 deletions

View File

@ -1,3 +1,3 @@
from .version import __version__, short_version #from .version import __version__, short_version
__all__ = ['__version__', 'short_version'] #__all__ = ['__version__', 'short_version']

View File

@ -111,8 +111,8 @@ class ShuffleUnit(nn.Module):
if self.combine == 'add': if self.combine == 'add':
self.depthwise_stride = 1 self.depthwise_stride = 1
self._combine_func = self._add self._combine_func = self._add
assert inplanes == planes, ('inplanes must be equal to ' assert inplanes == planes, (
'planes when combine is add') 'inplanes must be equal to planes when combine is add')
elif self.combine == 'concat': elif self.combine == 'concat':
self.depthwise_stride = 2 self.depthwise_stride = 2
self._combine_func = self._concat self._combine_func = self._concat
@ -273,20 +273,18 @@ class ShuffleNetv1(BaseBackbone):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layers = [] self.layers = nn.ModuleList()
for i, num_blocks in enumerate(self.stage_blocks): for i, num_blocks in enumerate(self.stage_blocks):
first_block = False if i == 0 else True first_block = False if i == 0 else True
layer = self.make_layer(channels[i], num_blocks, first_block) layer = self.make_layer(channels[i], num_blocks, first_block)
layer_name = f'layer{i + 1}' self.layers.append(layer)
self.add_module(layer_name, layer)
self.layers.append(layer_name)
def _freeze_stages(self): def _freeze_stages(self):
if self.frozen_stages >= 0: if self.frozen_stages >= 0:
for param in self.conv1.parameters(): for param in self.conv1.parameters():
param.requires_grad = False param.requires_grad = False
for i in range(1, self.frozen_stages + 1): for i in range(self.frozen_stages):
layer = getattr(self, f'layer{i}') layer = self.layers[i]
layer.eval() layer.eval()
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = False param.requires_grad = False
@ -336,8 +334,7 @@ class ShuffleNetv1(BaseBackbone):
x = self.maxpool(x) x = self.maxpool(x)
outs = [] outs = []
for i, layer_name in enumerate(self.layers): for i, layer in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x) x = layer(x)
if i in self.out_indices: if i in self.out_indices:
outs.append(x) outs.append(x)

View File

@ -104,11 +104,10 @@ def test_shufflenetv1_backbone():
model = ShuffleNetv1(frozen_stages=frozen_stages) model = ShuffleNetv1(frozen_stages=frozen_stages)
model.init_weights() model.init_weights()
model.train() model.train()
for layer in [model.conv1]: for param in model.conv1.parameters():
for param in layer.parameters():
assert param.requires_grad is False assert param.requires_grad is False
for i in range(1, frozen_stages + 1): for i in range(frozen_stages):
layer = getattr(model, f'layer{i}') layer = model.layers[i]
for mod in layer.modules(): for mod in layer.modules():
if isinstance(mod, _BatchNorm): if isinstance(mod, _BatchNorm):
assert mod.training is False assert mod.training is False