from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

__all__ = ['save_checkpoint', 'load_checkpoint', 'resume_from_checkpoint',
           'open_all_layers', 'open_specified_layers', 'count_num_param',
           'load_pretrained_weights']

from collections import OrderedDict
import shutil
import warnings
import os
import os.path as osp
from functools import partial
import pickle

import torch
import torch.nn as nn

from .tools import mkdir_if_missing


def save_checkpoint(state, save_dir, is_best=False, remove_module_from_keys=False):
    """Saves checkpoint.

    Args:
        state (dict): dictionary containing the model state and training
            metadata, e.g. ``state_dict``, ``epoch``, ``rank1``, ``optimizer``.
        save_dir (str): directory to save checkpoint.
        is_best (bool, optional): if True, this checkpoint will be copied and named
            "model-best.pth.tar". Default is False.
        remove_module_from_keys (bool, optional): whether to remove "module."
            from layer names. Default is False.

    Examples::
        >>> state = {
        >>>     'state_dict': model.state_dict(),
        >>>     'epoch': 10,
        >>>     'rank1': 0.5,
        >>>     'optimizer': optimizer.state_dict()
        >>> }
        >>> save_checkpoint(state, 'log/my_model')
    """
    mkdir_if_missing(save_dir)
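    # Note (added comment, not in the original file): the "module." prefix is what
    # ``nn.DataParallel`` adds to parameter names, so stripping it lets the saved
    # weights be loaded later into a model that is not wrapped in DataParallel.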
    if remove_module_from_keys:
        # remove 'module.' in state_dict's keys
        state_dict = state['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]
            new_state_dict[k] = v
        state['state_dict'] = new_state_dict
    # save
    epoch = state['epoch']
    fpath = osp.join(save_dir, 'model.pth.tar-' + str(epoch))
    torch.save(state, fpath)
    print('Checkpoint saved to "{}"'.format(fpath))
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model-best.pth.tar'))


def load_checkpoint(fpath):
    """Loads checkpoint.

    ``UnicodeDecodeError`` is handled explicitly, so checkpoints saved
    with Python 2 can be read with Python 3.

    Args:
        fpath (str): path to checkpoint.

    Returns:
        dict

    Examples::
        >>> from torchreid.utils import load_checkpoint
        >>> fpath = 'log/my_model/model.pth.tar-10'
        >>> checkpoint = load_checkpoint(fpath)
    """
    map_location = None if torch.cuda.is_available() else 'cpu'
    try:
        checkpoint = torch.load(fpath, map_location=map_location)
    except UnicodeDecodeError:
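        # Added comment (not in the original file): checkpoints pickled by Python 2
        # store byte strings; patching pickle to decode with 'latin1' lets
        # ``torch.load`` unpickle them under Python 3.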
        pickle.load = partial(pickle.load, encoding="latin1")
        pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
        checkpoint = torch.load(fpath, pickle_module=pickle, map_location=map_location)
    except Exception:
        print('Unable to load checkpoint from "{}"'.format(fpath))
        raise
    return checkpoint


def resume_from_checkpoint(fpath, model, optimizer=None):
    """Resumes training from a checkpoint.

    This will load (1) model weights and (2) ``state_dict``
    of optimizer if ``optimizer`` is not None.

    Args:
        fpath (str): path to checkpoint.
        model (nn.Module): model.
        optimizer (Optimizer, optional): an Optimizer.

    Returns:
        int: start_epoch.

    Examples::
        >>> from torchreid.utils import resume_from_checkpoint
        >>> fpath = 'log/my_model/model.pth.tar-10'
        >>> start_epoch = resume_from_checkpoint(fpath, model, optimizer)
    """
    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)
    model.load_state_dict(checkpoint['state_dict'])
    print('Loaded model weights')
    if optimizer is not None and 'optimizer' in checkpoint.keys():
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loaded optimizer')
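    # Added comment (assumption): the checkpoint is expected to follow the format
    # written by ``save_checkpoint`` above, so 'epoch' (and optionally 'rank1')
    # are read back directly from the loaded dictionary.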
    start_epoch = checkpoint['epoch']
    print('Last epoch = {}'.format(start_epoch))
    if 'rank1' in checkpoint.keys():
        print('Last rank1 = {:.1%}'.format(checkpoint['rank1']))
    return start_epoch


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1,
                         linear_decay=False, final_lr=0, max_epoch=100):
    """Adjusts learning rate.

    Deprecated.
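
    Examples::
        >>> # A hedged usage sketch (not from the original file); values are illustrative.
        >>> for epoch in range(60):
        >>>     adjust_learning_rate(optimizer, base_lr=0.0003, epoch=epoch, stepsize=20)
        >>>     # ... run one training epoch ...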
"""
|
|
if linear_decay:
|
|
# linearly decay learning rate from base_lr to final_lr
|
|
frac_done = epoch / max_epoch
|
|
lr = frac_done * final_lr + (1. - frac_done) * base_lr
|
|
else:
|
|
# decay learning rate by gamma for every stepsize
|
|
lr = base_lr * (gamma ** (epoch // stepsize))
|
|
|
|
for param_group in optimizer.param_groups:
|
|
param_group['lr'] = lr
|
|
|
|
|
|
def set_bn_to_eval(m):
    """Sets BatchNorm layers to eval mode."""
    # 1. no update for running mean and var
    # 2. scale and shift parameters are still trainable
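    # Typical usage (a sketch, assuming the standard ``nn.Module.apply`` pattern;
    # not part of the original file):
    #     model.train()
    #     model.apply(set_bn_to_eval)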
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


def open_all_layers(model):
    """Opens all layers in model for training.

    Examples::
        >>> from torchreid.utils import open_all_layers
        >>> open_all_layers(model)
    """
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """Opens specified layers in model for training while keeping
    other layers frozen.

    Args:
        model (nn.Module): neural net model.
        open_layers (str or list): layers open for training.

    Examples::
        >>> from torchreid.utils import open_specified_layers
        >>> # Only model.classifier will be updated.
        >>> open_layers = 'classifier'
        >>> open_specified_layers(model, open_layers)
        >>> # Only model.fc and model.classifier will be updated.
        >>> open_layers = ['fc', 'classifier']
        >>> open_specified_layers(model, open_layers)
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer in open_layers:
        assert hasattr(model, layer), \
            '"{}" is not an attribute of the model, please provide the correct name'.format(layer)
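
    # Added comment (not in the original file): ``named_children`` only iterates over
    # direct submodules, so entries in ``open_layers`` must be names of top-level
    # attributes of the model (e.g. 'classifier'), not nested layer paths.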
    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False


def count_num_param(model):
    """Counts the number of parameters in a model (in millions).

    If the model has a ``classifier`` module, its parameters are excluded,
    because the classifier is unused at test time.

    Examples::
        >>> from torchreid.utils import count_num_param
        >>> model_size = count_num_param(model)
    """
    num_param = sum(p.numel() for p in model.parameters()) / 1e+06

    if isinstance(model, nn.DataParallel):
        model = model.module

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
        # we ignore the classifier because it is unused at test time
        num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
    return num_param


def load_pretrained_weights(model, weight_path):
    """Loads pretrained weights to a model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> from torchreid.utils import load_pretrained_weights
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = load_checkpoint(weight_path)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
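
    # Added comment (not in the original file): merging the matched entries into the
    # model's own state dict and loading that back keeps the load tolerant of
    # missing or mismatched keys, which a strict ``load_state_dict`` call would reject.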
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path))
    else:
        print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
        if len(discarded_layers) > 0:
            print('** The following layers are discarded '
                  'due to unmatched keys or layer size: {}'.format(discarded_layers))