check if isinstance(model, nn.DataParallel)

pull/119/head
KaiyangZhou 2018-11-08 23:54:26 +00:00
parent 2bc29e37f5
commit 54b92ee617
1 changed files with 8 additions and 0 deletions

View File

@ -1,4 +1,5 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
@ -48,6 +49,9 @@ def open_specified_layers(model, open_layers):
- model (nn.Module): neural net model.
- open_layers (list): list of layer names.
"""
if isinstance(model, nn.DataParallel):
model = model.module
for layer in open_layers:
assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer)
@ -64,6 +68,10 @@ def open_specified_layers(model, open_layers):
def count_num_param(model):
num_param = sum(p.numel() for p in model.parameters()) / 1e+06
if isinstance(model, nn.DataParallel):
model = model.module
if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
# we ignore the classifier because it is unused at test time
num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06