diff --git a/torchreid/utils/torchtools.py b/torchreid/utils/torchtools.py index 8bcd21e..2111075 100644 --- a/torchreid/utils/torchtools.py +++ b/torchreid/utils/torchtools.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +from __future__ import print_function from __future__ import division import torch @@ -48,6 +49,9 @@ def open_specified_layers(model, open_layers): - model (nn.Module): neural net model. - open_layers (list): list of layer names. """ + if isinstance(model, nn.DataParallel): + model = model.module + for layer in open_layers: assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer) @@ -64,6 +68,10 @@ def open_specified_layers(model, open_layers): def count_num_param(model): num_param = sum(p.numel() for p in model.parameters()) / 1e+06 + + if isinstance(model, nn.DataParallel): + model = model.module + if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): # we ignore the classifier because it is unused at test time num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06