deep-person-reid/torchreid/engine/engine.py

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import sys
import os
import os.path as osp
import time
import datetime
import numpy as np
import cv2
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchreid
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking, mkdir_if_missing
from torchreid.losses import DeepSupervision
from torchreid import metrics
GRID_SPACING = 10
class Engine(object):
r"""A generic base Engine class for both image- and video-reid.
Args:
datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
or ``torchreid.data.VideoDataManager``.
model (nn.Module): model instance.
optimizer (Optimizer): an Optimizer.
scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
use_gpu (bool, optional): use gpu. Default is True.
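
    Examples::

        # A minimal sketch of building a concrete engine on this base class.
        # ``ImageSoftmaxEngine`` and the builder calls below come from the
        # wider torchreid API; argument values are illustrative assumptions.
        import torchreid
        datamanager = torchreid.data.ImageDataManager(
            root='reid-data', sources='market1501', height=256, width=128
        )
        model = torchreid.models.build_model(
            name='resnet50',
            num_classes=datamanager.num_train_pids,
            loss='softmax'
        )
        optimizer = torchreid.optim.build_optimizer(model, optim='adam')
        engine = torchreid.engine.ImageSoftmaxEngine(datamanager, model, optimizer)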
"""
def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_gpu=True):
self.datamanager = datamanager
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.use_gpu = (torch.cuda.is_available() and use_gpu)
self.writer = None
# check attributes
if not isinstance(self.model, nn.Module):
raise TypeError('model must be an instance of nn.Module')
def run(self, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=10,
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False, visactmap=False):
r"""A unified pipeline for training and evaluating a model.
Args:
save_dir (str): directory to save model.
max_epoch (int): maximum epoch.
start_epoch (int, optional): starting epoch. Default is 0.
fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is counted
in ``max_epoch``.
open_layers (str or list, optional): layers (attribute names) open for training.
start_eval (int, optional): from which epoch to start evaluation. Default is 0.
eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
is only performed at the end of training).
test_only (bool, optional): if True, only runs evaluation on test datasets.
Default is False.
            print_freq (int, optional): print frequency. Default is 10.
dist_metric (str, optional): distance metric used to compute distance matrix
between query and gallery. Default is "euclidean".
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
computing feature distance. Default is False.
visrank (bool, optional): visualizes ranked results. Default is False. It is recommended to
enable ``visrank`` when ``test_only`` is True. The ranked images will be saved to
"save_dir/visrank_dataset", e.g. "save_dir/visrank_market1501".
visrank_topk (int, optional): top-k ranked images to be visualized. Default is 10.
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
Default is False. This is only enabled when test_only=True.
visactmap (bool, optional): visualizes activation maps. Default is False.
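
        Examples::

            # a hedged sketch of a typical call, continuing the class-level
            # example above; the values shown are illustrative, not defaults
            engine.run(
                save_dir='log/resnet50',
                max_epoch=60,
                eval_freq=10,
                print_freq=10
            )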
"""
trainloader, testloader = self.datamanager.return_dataloaders()
if visrank and not test_only:
raise ValueError('visrank=True is valid only if test_only=True')
if test_only:
self.test(
0,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks,
rerank=rerank
)
return
if self.writer is None:
self.writer = SummaryWriter(log_dir=save_dir)
if visactmap:
self.visactmap(testloader, save_dir, self.datamanager.width, self.datamanager.height, print_freq)
return
time_start = time.time()
print('=> Start training')
for epoch in range(start_epoch, max_epoch):
self.train(epoch, max_epoch, trainloader, fixbase_epoch, open_layers, print_freq)
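            # evaluate periodically: only after ``start_eval``, every ``eval_freq``
            # epochs, and not on the last epoch (the final test below covers it)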
if (epoch+1)>=start_eval and eval_freq>0 and (epoch+1)%eval_freq==0 and (epoch+1)!=max_epoch:
rank1 = self.test(
epoch,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks
)
self._save_checkpoint(epoch, rank1, save_dir)
if max_epoch > 0:
print('=> Final test')
rank1 = self.test(
epoch,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks
)
self._save_checkpoint(epoch, rank1, save_dir)
elapsed = round(time.time() - time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print('Elapsed {}'.format(elapsed))
        if self.writer is not None:
            self.writer.close()
    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
r"""Performs training on source datasets for one epoch.
This will be called every epoch in ``run()``, e.g.
.. code-block:: python
for epoch in range(start_epoch, max_epoch):
self.train(some_arguments)
.. note::
This must be implemented in subclasses.
"""
raise NotImplementedError
def test(self, epoch, testloader, dist_metric='euclidean', normalize_feature=False,
visrank=False, visrank_topk=10, save_dir='', use_metric_cuhk03=False,
ranks=[1, 5, 10, 20], rerank=False):
r"""Tests model on target datasets.
.. note::
This function has been called in ``run()``.
.. note::
The test pipeline implemented in this function suits both image- and
            video-reid. In general, a subclass of Engine only needs to re-implement
            ``_extract_features()`` and ``_parse_data_for_eval()``, and most of the
            time even that is unnecessary. Please refer to the source code for more details.
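
        Examples::

            # a hedged sketch of a video-reid subclass overriding
            # ``_extract_features()``; the (b, s, c, h, w) clip layout is an
            # assumption about the video data loader, not stated in this file
            def _extract_features(self, input):
                self.model.eval()
                b, s, c, h, w = input.size()
                features = self.model(input.view(b * s, c, h, w))
                # average frame-level features over the sequence dimension
                return features.view(b, s, -1).mean(1)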
"""
targets = list(testloader.keys())
for name in targets:
domain = 'source' if name in self.datamanager.sources else 'target'
print('##### Evaluating {} ({}) #####'.format(name, domain))
queryloader = testloader[name]['query']
galleryloader = testloader[name]['gallery']
rank1 = self._evaluate(
epoch,
dataset_name=name,
queryloader=queryloader,
galleryloader=galleryloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks,
rerank=rerank
)
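        # note: when multiple target datasets are given, the returned rank1
        # is that of the last dataset in ``testloader``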
return rank1
@torch.no_grad()
def _evaluate(self, epoch, dataset_name='', queryloader=None, galleryloader=None,
dist_metric='euclidean', normalize_feature=False, visrank=False,
visrank_topk=10, save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20],
rerank=False):
batch_time = AverageMeter()
print('Extracting features from query set ...')
qf, q_pids, q_camids = [], [], [] # query features, query person IDs and query camera IDs
for batch_idx, data in enumerate(queryloader):
imgs, pids, camids = self._parse_data_for_eval(data)
if self.use_gpu:
imgs = imgs.cuda()
end = time.time()
features = self._extract_features(imgs)
batch_time.update(time.time() - end)
features = features.data.cpu()
qf.append(features)
q_pids.extend(pids)
q_camids.extend(camids)
qf = torch.cat(qf, 0)
q_pids = np.asarray(q_pids)
q_camids = np.asarray(q_camids)
print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))
print('Extracting features from gallery set ...')
gf, g_pids, g_camids = [], [], [] # gallery features, gallery person IDs and gallery camera IDs
end = time.time()
for batch_idx, data in enumerate(galleryloader):
imgs, pids, camids = self._parse_data_for_eval(data)
if self.use_gpu:
imgs = imgs.cuda()
end = time.time()
features = self._extract_features(imgs)
batch_time.update(time.time() - end)
features = features.data.cpu()
gf.append(features)
g_pids.extend(pids)
g_camids.extend(camids)
gf = torch.cat(gf, 0)
g_pids = np.asarray(g_pids)
g_camids = np.asarray(g_camids)
print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))
print('Speed: {:.4f} sec/batch'.format(batch_time.avg))
if normalize_feature:
            print('Normalizing features with L2 norm ...')
qf = F.normalize(qf, p=2, dim=1)
gf = F.normalize(gf, p=2, dim=1)
print('Computing distance matrix with metric={} ...'.format(dist_metric))
distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
distmat = distmat.numpy()
if rerank:
print('Applying person re-ranking ...')
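            # k-reciprocal re-ranking (Zhong et al. CVPR'17) refines the
            # query-gallery distances with query-query and gallery-gallery distances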
distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
distmat = re_ranking(distmat, distmat_qq, distmat_gg)
print('Computing CMC and mAP ...')
cmc, mAP = metrics.evaluate_rank(
distmat,
q_pids,
g_pids,
q_camids,
g_camids,
use_metric_cuhk03=use_metric_cuhk03
)
print('** Results **')
print('mAP: {:.1%}'.format(mAP))
print('CMC curve')
for r in ranks:
print('Rank-{:<3}: {:.1%}'.format(r, cmc[r-1]))
if visrank:
visualize_ranked_results(
distmat,
self.datamanager.return_testdataset_by_name(dataset_name),
self.datamanager.data_type,
width=self.datamanager.width,
height=self.datamanager.height,
save_dir=osp.join(save_dir, 'visrank_'+dataset_name),
topk=visrank_topk
)
return cmc[0]
@torch.no_grad()
def visactmap(self, testloader, save_dir, width, height, print_freq):
"""Visualizes CNN activation maps to see where the CNN focuses on to extract features.
This function takes as input the query images of target datasets
Reference:
- Zagoruyko and Komodakis. Paying more attention to attention: Improving the
              performance of convolutional neural networks via attention transfer. ICLR, 2017.
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
"""
self.model.eval()
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
for target in list(testloader.keys()):
queryloader = testloader[target]['query']
# original images and activation maps are saved individually
actmap_dir = osp.join(save_dir, 'actmap_'+target)
mkdir_if_missing(actmap_dir)
print('Visualizing activation maps for {} ...'.format(target))
for batch_idx, data in enumerate(queryloader):
imgs, paths = data[0], data[3]
if self.use_gpu:
imgs = imgs.cuda()
# forward to get convolutional feature maps
try:
outputs = self.model(imgs, return_featuremaps=True)
except TypeError:
raise TypeError('forward() got unexpected keyword argument "return_featuremaps". ' \
'Please add return_featuremaps as an input argument to forward(). When ' \
'return_featuremaps=True, return feature maps only.')
if outputs.dim() != 4:
                    raise ValueError('The model output is supposed to have '
                                     'shape of (b, c, h, w), i.e. 4 dimensions, but got {} dimensions. '
                                     'Please make sure the model outputs the last convolutional '
                                     'feature maps in eval mode.'.format(outputs.dim()))
# compute activation maps
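                # attention-transfer style: square the feature maps, sum over
                # channels, then L2-normalize over spatial locations to obtain
                # a per-image saliency map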
outputs = (outputs**2).sum(1)
b, h, w = outputs.size()
outputs = outputs.view(b, h*w)
outputs = F.normalize(outputs, p=2, dim=1)
outputs = outputs.view(b, h, w)
if self.use_gpu:
imgs, outputs = imgs.cpu(), outputs.cpu()
for j in range(outputs.size(0)):
# get image name
path = paths[j]
imname = osp.basename(osp.splitext(path)[0])
# RGB image
img = imgs[j, ...]
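                    # undo ImageNet normalization channel by channel, then clamp to [0, 1]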
for t, m, s in zip(img, imagenet_mean, imagenet_std):
t.mul_(s).add_(m).clamp_(0, 1)
img_np = np.uint8(np.floor(img.numpy() * 255))
img_np = img_np.transpose((1, 2, 0)) # (c, h, w) -> (h, w, c)
# activation map
am = outputs[j, ...].numpy()
am = cv2.resize(am, (width, height))
                    am = 255 * (am - np.min(am)) / (np.max(am) - np.min(am) + 1e-12)  # min-max scale to [0, 255]
am = np.uint8(np.floor(am))
am = cv2.applyColorMap(am, cv2.COLORMAP_JET)
# overlapped
overlapped = img_np * 0.3 + am * 0.7
overlapped[overlapped>255] = 255
overlapped = overlapped.astype(np.uint8)
# save images in a single figure (add white spacing between images)
# from left to right: original image, activation map, overlapped image
grid_img = 255 * np.ones((height, 3*width+2*GRID_SPACING, 3), dtype=np.uint8)
grid_img[:, :width, :] = img_np[:, :, ::-1]
grid_img[:, width+GRID_SPACING: 2*width+GRID_SPACING, :] = am
grid_img[:, 2*width+2*GRID_SPACING:, :] = overlapped
cv2.imwrite(osp.join(actmap_dir, imname+'.jpg'), grid_img)
if (batch_idx+1) % print_freq == 0:
print('- done batch {}/{}'.format(batch_idx+1, len(queryloader)))
def _compute_loss(self, criterion, outputs, targets):
if isinstance(outputs, (tuple, list)):
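            # multiple outputs (e.g. deeply-supervised branches): DeepSupervision
            # applies the criterion to each output and averages the losses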
loss = DeepSupervision(criterion, outputs, targets)
else:
loss = criterion(outputs, targets)
return loss
def _extract_features(self, input):
self.model.eval()
return self.model(input)
def _parse_data_for_train(self, data):
imgs = data[0]
pids = data[1]
return imgs, pids
def _parse_data_for_eval(self, data):
imgs = data[0]
pids = data[1]
camids = data[2]
return imgs, pids, camids
def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
save_checkpoint({
'state_dict': self.model.state_dict(),
'epoch': epoch + 1,
'rank1': rank1,
'optimizer': self.optimizer.state_dict(),
}, save_dir, is_best=is_best)