check if isinstance(model, nn.DataParallel)
parent
2bc29e37f5
commit
54b92ee617
|
@ -1,4 +1,5 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
|
@ -48,6 +49,9 @@ def open_specified_layers(model, open_layers):
|
|||
- model (nn.Module): neural net model.
|
||||
- open_layers (list): list of layer names.
|
||||
"""
|
||||
if isinstance(model, nn.DataParallel):
|
||||
model = model.module
|
||||
|
||||
for layer in open_layers:
|
||||
assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer)
|
||||
|
||||
|
@ -64,6 +68,10 @@ def open_specified_layers(model, open_layers):
|
|||
|
||||
def count_num_param(model):
|
||||
num_param = sum(p.numel() for p in model.parameters()) / 1e+06
|
||||
|
||||
if isinstance(model, nn.DataParallel):
|
||||
model = model.module
|
||||
|
||||
if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
|
||||
# we ignore the classifier because it is unused at test time
|
||||
num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
|
||||
|
|
Loading…
Reference in New Issue