2019-03-20 01:26:08 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import division
|
|
|
|
|
|
|
|
import os.path as osp
|
|
|
|
import time
|
|
|
|
import datetime
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2019-05-10 06:11:10 +08:00
|
|
|
from torch.nn import functional as F
|
2019-08-23 05:41:21 +08:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
2019-03-20 01:26:08 +08:00
|
|
|
|
2019-11-28 00:35:54 +08:00
|
|
|
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking
|
2019-03-20 01:26:08 +08:00
|
|
|
from torchreid.losses import DeepSupervision
|
|
|
|
from torchreid import metrics
|
|
|
|
|
|
|
|
|
|
|
|
class Engine(object):
    r"""A generic base Engine class for both image- and video-reid.

    Args:
        datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
            or ``torchreid.data.VideoDataManager``. Its ``train_loader`` and
            ``test_loader`` attributes are read once, at construction time.
        model (nn.Module): model instance.
        optimizer (Optimizer, optional): an Optimizer. Default is None.
        scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
        use_gpu (bool, optional): use gpu. Default is True. The GPU is only used
            when ``torch.cuda.is_available()`` is also True.
    """

    def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_gpu=True):
        self.datamanager = datamanager
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # Silently fall back to CPU when CUDA is requested but unavailable.
        self.use_gpu = (torch.cuda.is_available() and use_gpu)
        # Created lazily in run(), after the test-only early return, so that
        # evaluation-only runs never create a SummaryWriter.
        self.writer = None
        self.train_loader = self.datamanager.train_loader
        self.test_loader = self.datamanager.test_loader

    def run(
        self,
        save_dir='log',
        max_epoch=0,
        start_epoch=0,
        print_freq=10,
        fixbase_epoch=0,
        open_layers=None,
        start_eval=0,
        eval_freq=-1,
        test_only=False,
        dist_metric='euclidean',
        normalize_feature=False,
        visrank=False,
        visrank_topk=10,
        use_metric_cuhk03=False,
        ranks=[1, 5, 10, 20],
        rerank=False
    ):
        r"""A unified pipeline for training and evaluating a model.

        Args:
            save_dir (str, optional): directory to save model (also used as the
                TensorBoard log directory). Default is "log".
            max_epoch (int, optional): maximum epoch. Default is 0.
            start_epoch (int, optional): starting epoch. Default is 0.
            print_freq (int, optional): print_frequency. Default is 10.
            fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
                while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is counted
                in ``max_epoch``.
            open_layers (str or list, optional): layers (attribute names) open for training.
            start_eval (int, optional): from which epoch to start evaluation. Default is 0.
            eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
                is only performed at the end of training).
            test_only (bool, optional): if True, only runs evaluation on test datasets.
                Default is False.
            dist_metric (str, optional): distance metric used to compute distance matrix
                between query and gallery. Default is "euclidean".
            normalize_feature (bool, optional): performs L2 normalization on feature vectors before
                computing feature distance. Default is False.
            visrank (bool, optional): visualizes ranked results. Default is False. It is recommended to
                enable ``visrank`` when ``test_only`` is True. The ranked images will be saved to
                "save_dir/visrank_dataset", e.g. "save_dir/visrank_market1501".
            visrank_topk (int, optional): top-k ranked images to be visualized. Default is 10.
            use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
                Default is False. This should be enabled when using cuhk03 classic split.
            ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
            rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
                Default is False. This is only enabled when test_only=True.

        Raises:
            ValueError: if ``visrank`` is True while ``test_only`` is False.
        """
        if visrank and not test_only:
            raise ValueError('visrank can be set to True only if test_only=True')

        if test_only:
            self.test(
                0,
                dist_metric=dist_metric,
                normalize_feature=normalize_feature,
                visrank=visrank,
                visrank_topk=visrank_topk,
                save_dir=save_dir,
                use_metric_cuhk03=use_metric_cuhk03,
                ranks=ranks,
                rerank=rerank
            )
            return

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=save_dir)

        time_start = time.time()
        print('=> Start training')

        for epoch in range(start_epoch, max_epoch):
            # The writer is stored on the instance; there is no local named
            # ``writer`` in this scope.
            self.train(
                epoch,
                max_epoch,
                self.writer,
                print_freq=print_freq,
                fixbase_epoch=fixbase_epoch,
                open_layers=open_layers
            )

            # Periodic evaluation; the last epoch is skipped here because the
            # final test below always covers it.
            if (epoch + 1) >= start_eval \
                    and eval_freq > 0 \
                    and (epoch + 1) % eval_freq == 0 \
                    and (epoch + 1) != max_epoch:
                rank1 = self.test(
                    epoch,
                    dist_metric=dist_metric,
                    normalize_feature=normalize_feature,
                    visrank=visrank,
                    visrank_topk=visrank_topk,
                    save_dir=save_dir,
                    use_metric_cuhk03=use_metric_cuhk03,
                    ranks=ranks
                )
                self._save_checkpoint(epoch, rank1, save_dir)

        if max_epoch > 0:
            print('=> Final test')
            # Do not rely on the loop variable here: when start_epoch >= max_epoch
            # the loop body never runs and ``epoch`` would be unbound. When the
            # loop did run, its last value is exactly max_epoch - 1.
            last_epoch = max_epoch - 1
            rank1 = self.test(
                last_epoch,
                dist_metric=dist_metric,
                normalize_feature=normalize_feature,
                visrank=visrank,
                visrank_topk=visrank_topk,
                save_dir=save_dir,
                use_metric_cuhk03=use_metric_cuhk03,
                ranks=ranks
            )
            self._save_checkpoint(last_epoch, rank1, save_dir)

        elapsed = round(time.time() - time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print('Elapsed {}'.format(elapsed))
        if self.writer is not None:
            self.writer.close()

    def train(self, *args, **kwargs):
        r"""Performs training on source datasets for one epoch.

        This will be called every epoch in ``run()``, e.g.

        .. code-block:: python

            for epoch in range(start_epoch, max_epoch):
                self.train(some_arguments)

        ``run()`` invokes it as
        ``self.train(epoch, max_epoch, writer, print_freq=..., fixbase_epoch=..., open_layers=...)``.

        .. note::

            This must be implemented in subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def test(
        self,
        epoch,
        dist_metric='euclidean',
        normalize_feature=False,
        visrank=False,
        visrank_topk=10,
        save_dir='',
        use_metric_cuhk03=False,
        ranks=[1, 5, 10, 20],
        rerank=False
    ):
        r"""Tests model on every dataset in ``self.test_loader``.

        Each dataset is labelled "source" if its name is in
        ``self.datamanager.sources``, otherwise "target".

        .. note::

            This function has been called in ``run()``.

        .. note::

            The test pipeline implemented in this function suits both image- and
            video-reid. In general, a subclass of Engine only needs to re-implement
            ``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time),
            but not a must. Please refer to the source code for more details.

        Returns:
            The rank-1 accuracy of the LAST evaluated dataset, or None if
            ``self.test_loader`` contains no datasets.
        """
        targets = list(self.test_loader.keys())
        # Keeps the return value defined even when there is nothing to evaluate.
        rank1 = None

        for name in targets:
            domain = 'source' if name in self.datamanager.sources else 'target'
            print('##### Evaluating {} ({}) #####'.format(name, domain))
            query_loader = self.test_loader[name]['query']
            gallery_loader = self.test_loader[name]['gallery']
            rank1 = self._evaluate(
                epoch,
                dataset_name=name,
                query_loader=query_loader,
                gallery_loader=gallery_loader,
                dist_metric=dist_metric,
                normalize_feature=normalize_feature,
                visrank=visrank,
                visrank_topk=visrank_topk,
                save_dir=save_dir,
                use_metric_cuhk03=use_metric_cuhk03,
                ranks=ranks,
                rerank=rerank
            )

        return rank1

    @torch.no_grad()
    def _evaluate(
        self,
        epoch,
        dataset_name='',
        query_loader=None,
        gallery_loader=None,
        dist_metric='euclidean',
        normalize_feature=False,
        visrank=False,
        visrank_topk=10,
        save_dir='',
        use_metric_cuhk03=False,
        ranks=[1, 5, 10, 20],
        rerank=False
    ):
        """Extracts query/gallery features, ranks them and prints CMC and mAP.

        Gradients are disabled for the whole evaluation via ``torch.no_grad()``.

        Returns:
            ``cmc[0]``, i.e. the rank-1 accuracy for this dataset.
        """
        # Shared by the query and gallery passes, so the printed speed is the
        # average over both.
        batch_time = AverageMeter()

        def _feature_extraction(data_loader):
            # Returns (features concatenated along dim 0, pids array, camids array).
            f_, pids_, camids_ = [], [], []
            for batch_idx, data in enumerate(data_loader):
                imgs, pids, camids = self._parse_data_for_eval(data)
                if self.use_gpu:
                    imgs = imgs.cuda()
                end = time.time()
                features = self._extract_features(imgs)
                batch_time.update(time.time() - end)
                # Move to CPU so the distance computation below runs on CPU tensors.
                features = features.data.cpu()
                f_.append(features)
                pids_.extend(pids)
                camids_.extend(camids)
            f_ = torch.cat(f_, 0)
            pids_ = np.asarray(pids_)
            camids_ = np.asarray(camids_)
            return f_, pids_, camids_

        print('Extracting features from query set ...')
        qf, q_pids, q_camids = _feature_extraction(query_loader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids, g_camids = _feature_extraction(gallery_loader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        if normalize_feature:
            print('Normalizing features with L2 norm ...')
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        print('Computing distance matrix with metric={} ...'.format(dist_metric))
        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        if rerank:
            print('Applying person re-ranking ...')
            # Convert to NumPy so all three matrices handed to re_ranking share
            # the same type as ``distmat``.
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric).numpy()
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric).numpy()
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)

        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(
            distmat,
            q_pids,
            g_pids,
            q_camids,
            g_camids,
            use_metric_cuhk03=use_metric_cuhk03
        )

        print('** Results **')
        print('mAP: {:.1%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            # ranks are 1-based, cmc is 0-indexed.
            print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.return_query_and_gallery_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrank_' + dataset_name),
                topk=visrank_topk
            )

        return cmc[0]

    def _compute_loss(self, criterion, outputs, targets):
        """Applies ``criterion`` to the model outputs.

        A tuple/list of outputs is delegated to ``DeepSupervision(criterion,
        outputs, targets)`` (presumably applying the criterion to each output
        and aggregating -- see torchreid.losses). Anything else is passed
        straight to ``criterion(outputs, targets)``.
        """
        if isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, targets)
        else:
            loss = criterion(outputs, targets)
        return loss

    def _extract_features(self, input):
        """Switches the model to eval mode and returns its forward output for ``input``."""
        self.model.eval()
        return self.model(input)

    def _parse_data_for_train(self, data):
        """Returns ``(imgs, pids)`` taken from positions 0 and 1 of a training batch."""
        imgs = data[0]
        pids = data[1]
        return imgs, pids

    def _parse_data_for_eval(self, data):
        """Returns ``(imgs, pids, camids)`` taken from positions 0-2 of an eval batch."""
        imgs = data[0]
        pids = data[1]
        camids = data[2]
        return imgs, pids, camids

    def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
        """Saves model weights, ``epoch + 1``, ``rank1`` and optimizer state to ``save_dir``.

        The optimizer entry is None when the engine was built without an
        optimizer (``optimizer`` defaults to None in ``__init__``).
        """
        optimizer_state = self.optimizer.state_dict() if self.optimizer is not None else None
        save_checkpoint({
            'state_dict': self.model.state_dict(),
            'epoch': epoch + 1,
            'rank1': rank1,
            'optimizer': optimizer_state,
        }, save_dir, is_best=is_best)
|