deep-person-reid/torchreid/utils/torchtools.py

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

__all__ = [
    'save_checkpoint',
    'load_checkpoint',
    'resume_from_checkpoint',
    'open_all_layers',
    'open_specified_layers',
    'count_num_param',
    'load_pretrained_weights'
]

from collections import OrderedDict
import shutil
import warnings
import os
import os.path as osp
from functools import partial
import pickle

import torch
import torch.nn as nn

from .tools import mkdir_if_missing


def save_checkpoint(state, save_dir, is_best=False, remove_module_from_keys=False):
    mkdir_if_missing(save_dir)
    if remove_module_from_keys:
        # remove 'module.' in state_dict's keys
        state_dict = state['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]
            new_state_dict[k] = v
        state['state_dict'] = new_state_dict
    # save
    epoch = state['epoch']
    fpath = osp.join(save_dir, 'model.pth.tar-' + str(epoch))
    torch.save(state, fpath)
    print('Checkpoint saved to "{}"'.format(fpath))
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
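
# Illustrative usage sketch (not part of the original file): `model`, `optimizer`,
# `epoch`, `rank1`, `is_best` and the save directory are assumed to come from the
# training script. The dict keys below are the ones this module itself reads back
# in resume_from_checkpoint().
#
#     save_checkpoint({
#         'state_dict': model.state_dict(),
#         'epoch': epoch + 1,
#         'rank1': rank1,
#         'optimizer': optimizer.state_dict(),
#     }, save_dir='log/my_experiment', is_best=is_best)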


def load_checkpoint(fpath):
    map_location = None if torch.cuda.is_available() else 'cpu'
    try:
        checkpoint = torch.load(fpath, map_location=map_location)
    except UnicodeDecodeError:
        # checkpoints pickled under Python 2 need to be decoded as latin1
        pickle.load = partial(pickle.load, encoding="latin1")
        pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
        checkpoint = torch.load(fpath, pickle_module=pickle, map_location=map_location)
    except Exception:
        print('Unable to load checkpoint from "{}"'.format(fpath))
        raise
    return checkpoint


def resume_from_checkpoint(fpath, model, optimizer=None):
    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)
    model.load_state_dict(checkpoint['state_dict'])
    print('Loaded model weights')
    if optimizer is not None and 'optimizer' in checkpoint.keys():
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loaded optimizer')
    start_epoch = checkpoint['epoch']
    print('Last epoch = {}'.format(start_epoch))
    if 'rank1' in checkpoint.keys():
        print('Last rank1 = {:.1%}'.format(checkpoint['rank1']))
    return start_epoch
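
# Illustrative usage sketch (assumed names, not from the original file): resuming
# training from a checkpoint written by save_checkpoint(); the returned epoch is
# used as the starting epoch of the new run.
#
#     start_epoch = resume_from_checkpoint(
#         'log/my_experiment/model.pth.tar-30', model, optimizer=optimizer)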


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1,
                         linear_decay=False, final_lr=0, max_epoch=100):
    if linear_decay:
        # linearly decay learning rate from base_lr to final_lr
        frac_done = epoch / max_epoch
        lr = frac_done * final_lr + (1. - frac_done) * base_lr
    else:
        # decay learning rate by gamma for every stepsize
        lr = base_lr * (gamma ** (epoch // stepsize))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
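
# Worked example of the step decay above (illustrative): with base_lr=0.1,
# stepsize=20 and gamma=0.1, lr = 0.1 * 0.1 ** (epoch // 20), i.e. 0.1 for
# epochs 0-19, 0.01 for epochs 20-39 and 0.001 for epochs 40-59.
# Typically called once per epoch:
#
#     for epoch in range(max_epoch):
#         adjust_learning_rate(optimizer, base_lr=0.1, epoch=epoch, stepsize=20)
#         # ... train for one epoch ...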


def set_bn_to_eval(m):
    # 1. no update for running mean and var
    # 2. scale and shift parameters are still trainable
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
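
# Illustrative usage (assumed, not from the original file): nn.Module.apply()
# visits every submodule, so the call below freezes the running statistics of
# all BatchNorm layers while their affine weights stay trainable.
#
#     model.train()
#     model.apply(set_bn_to_eval)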


def open_all_layers(model):
    """Opens all layers in model for training.

    Args:
        model (nn.Module): neural net model.
    """
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """Opens specified layers in model for training while keeping
    other layers frozen.

    Args:
        model (nn.Module): neural net model.
        open_layers (str or list): layers open for training.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer in open_layers:
        assert hasattr(model, layer), '"{}" is not an attribute of the model, please provide the correct name'.format(layer)

    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False
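
# Illustrative usage sketch (the layer names are assumptions about the model at
# hand, not from the original file): train only the classifier head for a few
# warm-up epochs, then unfreeze everything.
#
#     open_specified_layers(model, 'classifier')           # a single layer by name
#     open_specified_layers(model, ['fc', 'classifier'])   # or a list of names
#     # ... warm-up epochs ...
#     open_all_layers(model)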


def count_num_param(model):
    """Counts the number of parameters in a model, in millions.

    Args:
        model (nn.Module): neural network.
    """
    num_param = sum(p.numel() for p in model.parameters()) / 1e+06

    if isinstance(model, nn.DataParallel):
        model = model.module

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
        # we ignore the classifier because it is unused at test time
        num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06

    return num_param
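
# Illustrative usage (assumed name `model`): since the returned value is already
# in millions of parameters, it can be printed directly as a model-size summary.
#
#     print('Model size: {:.3f} M'.format(count_num_param(model)))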


def load_pretrained_weights(model, weight_path):
    """Loads pretrained weights to model.

    Incompatible layers (unmatched in name or size) will be ignored.

    Args:
        model (nn.Module): network model, which must not be nn.DataParallel.
        weight_path (str): path to pretrained weights.
    """
    checkpoint = load_checkpoint(weight_path)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        # If the pretrained state_dict was saved by nn.DataParallel,
        # keys contain a "module." prefix, which should be stripped.
        if k.startswith('module.'):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path))
    else:
        print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
        if len(discarded_layers) > 0:
            print('** The following layers are discarded '
                  'due to unmatched keys or layer size: {}'.format(discarded_layers))
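
# Illustrative usage sketch (the weight path is a placeholder, not from the
# original file): initialise a model from pretrained weights before fine-tuning;
# layers whose names or shapes do not match are skipped and reported.
#
#     load_pretrained_weights(model, 'pretrained/resnet50_market1501.pth.tar')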