deep-person-reid/torchreid/utils/torchtools.py

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from collections import OrderedDict

import torch
import torch.nn as nn


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1,
                         linear_decay=False, final_lr=0, max_epoch=100):
    if linear_decay:
        # linearly decay learning rate from base_lr to final_lr
        frac_done = epoch / max_epoch
        lr = frac_done * final_lr + (1. - frac_done) * base_lr
    else:
        # decay learning rate by gamma for every stepsize
        lr = base_lr * (gamma ** (epoch // stepsize))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
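

# Usage sketch (assumption, not part of the original file): call once at the start
# of each epoch so every param group picks up the decayed rate. 'optimizer' and
# 'args' are placeholders from a hypothetical training script.
#
#   for epoch in range(args.max_epoch):
#       adjust_learning_rate(optimizer, args.lr, epoch, stepsize=args.stepsize)
#       train(epoch, model, criterion, optimizer, trainloader)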


def set_bn_to_eval(m):
    # 1. no update for running mean and var
    # 2. scale and shift parameters are still trainable
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
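

# Usage sketch (assumption): apply recursively via nn.Module.apply so every
# BatchNorm submodule stops updating its running statistics while the rest of
# the model stays in training mode.
#
#   model.train()
#   model.apply(set_bn_to_eval)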


def open_all_layers(model):
    """
    Open all layers in model for training.

    Args:
    - model (nn.Module): neural net model.
    """
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """
    Open specified layers in model for training while keeping
    other layers frozen.

    Args:
    - model (nn.Module): neural net model.
    - open_layers (list): list of layer names.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    for layer in open_layers:
        assert hasattr(model, layer), '"{}" is not an attribute of the model, please provide the correct name'.format(layer)

    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False
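

# Usage sketch (assumption): a common warm-up pattern is to train only the newly
# initialized layers for a few epochs before unfreezing the whole network.
# 'fixbase_epoch' and the layer name 'classifier' are hypothetical and depend on
# the model being trained.
#
#   if epoch < fixbase_epoch:
#       open_specified_layers(model, ['classifier'])
#   else:
#       open_all_layers(model)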


def count_num_param(model):
    num_param = sum(p.numel() for p in model.parameters()) / 1e+06

    if isinstance(model, nn.DataParallel):
        model = model.module

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
        # we ignore the classifier because it is unused at test time
        num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
    return num_param
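

# Usage sketch (assumption): the returned value is in millions of parameters,
# excluding any 'classifier' head since it is unused at test time.
#
#   print('Model size: {:.3f} M'.format(count_num_param(model)))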


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        if isinstance(output, (tuple, list)):
            output = output[0]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            acc = correct_k.mul_(100.0 / batch_size)
            res.append(acc.item())
        return res
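

# Usage sketch (assumption): 'outputs' are class logits of shape (batch, num_classes)
# and 'targets' are integer labels; the result is a list of percentages, one per k.
#
#   prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))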


def load_pretrained_weights(model, weight_path):
    checkpoint = torch.load(weight_path)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    if len(matched_layers) == 0:
        print('ERROR: the pretrained weights "{}" cannot be loaded, please check the key names manually (ignored and continue)'.format(weight_path))
    else:
        print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
        if len(discarded_layers) > 0:
            print('* The following layers are discarded due to unmatched keys or layer size: {}'.format(discarded_layers))