add person reranking
parent
6342fb3055
commit
5839a026c3
scripts
torchreid
engine
utils
|
@ -125,13 +125,13 @@ def init_parser():
|
|||
help='learning rate decay')
|
||||
|
||||
# ************************************************************
|
||||
# Cross entropy loss-specific setting
|
||||
# Cross entropy loss
|
||||
# ************************************************************
|
||||
parser.add_argument('--label-smooth', action='store_true',
|
||||
help='use label smoothing regularizer in cross entropy loss')
|
||||
|
||||
# ************************************************************
|
||||
# Hard triplet loss-specific setting
|
||||
# Hard triplet loss
|
||||
# ************************************************************
|
||||
parser.add_argument('--margin', type=float, default=0.3,
|
||||
help='margin for triplet loss')
|
||||
|
@ -165,6 +165,13 @@ def init_parser():
|
|||
help='distance metric')
|
||||
parser.add_argument('--ranks', type=str, default=[1, 5, 10, 20], nargs='+',
|
||||
help='cmc ranks')
|
||||
parser.add_argument('--rerank', action='store_true',
|
||||
help='use person re-ranking (by Zhong et al. CVPR2017)')
|
||||
|
||||
parser.add_argument('--visrank', action='store_true',
|
||||
help='visualize ranked results, only available in evaluation mode')
|
||||
parser.add_argument('--visrank-topk', type=int, default=20,
|
||||
help='visualize topk ranks')
|
||||
|
||||
# ************************************************************
|
||||
# Miscs
|
||||
|
@ -182,11 +189,7 @@ def init_parser():
|
|||
parser.add_argument('--gpu-devices', type=str, default='0',
|
||||
help='gpu device ids for CUDA_VISIBLE_DEVICES')
|
||||
parser.add_argument('--use-avai-gpus', action='store_true',
|
||||
help='use available gpus instead of specified devices (useful when using managed clusters)')
|
||||
parser.add_argument('--visrank', action='store_true',
|
||||
help='visualize ranked results, only available in evaluation mode')
|
||||
parser.add_argument('--visrank-topk', type=int, default=20,
|
||||
help='visualize topk ranks')
|
||||
help='use available gpus instead of specified devices')
|
||||
|
||||
return parser
|
||||
|
||||
|
@ -278,5 +281,6 @@ def engine_run_kwargs(parsed_args):
|
|||
'visrank': parsed_args.visrank,
|
||||
'visrank_topk': parsed_args.visrank_topk,
|
||||
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,
|
||||
'ranks': parsed_args.ranks
|
||||
'ranks': parsed_args.ranks,
|
||||
'rerank': parsed_args.rerank
|
||||
}
|
|
@ -12,13 +12,13 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
import torchreid
|
||||
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint
|
||||
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking
|
||||
from torchreid.losses import DeepSupervision
|
||||
from torchreid import metrics
|
||||
|
||||
|
||||
class Engine(object):
|
||||
"""A generic base Engine class for both image- and video-reid.
|
||||
r"""A generic base Engine class for both image- and video-reid.
|
||||
|
||||
Args:
|
||||
datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
|
||||
|
@ -43,8 +43,8 @@ class Engine(object):
|
|||
def run(self, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
|
||||
start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
|
||||
dist_metric='euclidean', visrank=False, visrank_topk=20,
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20]):
|
||||
"""A unified pipeline for training and evaluating a model.
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
|
||||
r"""A unified pipeline for training and evaluating a model.
|
||||
|
||||
Args:
|
||||
save_dir (str): directory to save model.
|
||||
|
@ -70,6 +70,8 @@ class Engine(object):
|
|||
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
|
||||
Default is False. This should be enabled when using cuhk03 classic split.
|
||||
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
|
||||
rerank (bool, optional): use person re-ranking (by Zhong et al. CVPR'17).
|
||||
Default is False. This is only enabled when test_only=True.
|
||||
"""
|
||||
trainloader, testloader = self.datamanager.return_dataloaders()
|
||||
|
||||
|
@ -82,7 +84,8 @@ class Engine(object):
|
|||
visrank_topk=visrank_topk,
|
||||
save_dir=save_dir,
|
||||
use_metric_cuhk03=use_metric_cuhk03,
|
||||
ranks=ranks
|
||||
ranks=ranks,
|
||||
rerank=rerank
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -131,7 +134,7 @@ class Engine(object):
|
|||
print('Elapsed {}'.format(elapsed))
|
||||
|
||||
def train(self):
|
||||
"""Performs training on source datasets for one epoch.
|
||||
r"""Performs training on source datasets for one epoch.
|
||||
|
||||
This will be called every epoch in ``run()``, e.g.
|
||||
|
||||
|
@ -147,8 +150,8 @@ class Engine(object):
|
|||
raise NotImplementedError
|
||||
|
||||
def test(self, epoch, testloader, dist_metric='euclidean', visrank=False, visrank_topk=20,
|
||||
save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20]):
|
||||
"""Tests model on target datasets.
|
||||
save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
|
||||
r"""Tests model on target datasets.
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -176,6 +179,8 @@ class Engine(object):
|
|||
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
|
||||
Default is False. This should be enabled when using cuhk03 classic split.
|
||||
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
|
||||
rerank (bool, optional): use person re-ranking (by Zhong et al. CVPR'17).
|
||||
Default is False.
|
||||
"""
|
||||
targets = list(testloader.keys())
|
||||
|
||||
|
@ -194,7 +199,8 @@ class Engine(object):
|
|||
visrank_topk=visrank_topk,
|
||||
save_dir=save_dir,
|
||||
use_metric_cuhk03=use_metric_cuhk03,
|
||||
ranks=ranks
|
||||
ranks=ranks,
|
||||
rerank=rerank
|
||||
)
|
||||
|
||||
return rank1
|
||||
|
@ -202,13 +208,13 @@ class Engine(object):
|
|||
@torch.no_grad()
|
||||
def _evaluate(self, epoch, dataset_name='', queryloader=None, galleryloader=None,
|
||||
dist_metric='euclidean', visrank=False, visrank_topk=20, save_dir='',
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20]):
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
|
||||
batch_time = AverageMeter()
|
||||
|
||||
self.model.eval()
|
||||
|
||||
print('Extracting features from query set ...')
|
||||
qf, q_pids, q_camids = [], [], []
|
||||
qf, q_pids, q_camids = [], [], [] # query features, query person IDs and query camera IDs
|
||||
for batch_idx, data in enumerate(queryloader):
|
||||
imgs, pids, camids = self._parse_data_for_eval(data)
|
||||
if self.use_gpu:
|
||||
|
@ -226,7 +232,7 @@ class Engine(object):
|
|||
print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))
|
||||
|
||||
print('Extracting features from gallery set ...')
|
||||
gf, g_pids, g_camids = [], [], []
|
||||
gf, g_pids, g_camids = [], [], [] # gallery features, gallery person IDs and gallery camera IDs
|
||||
end = time.time()
|
||||
for batch_idx, data in enumerate(galleryloader):
|
||||
imgs, pids, camids = self._parse_data_for_eval(data)
|
||||
|
@ -249,6 +255,12 @@ class Engine(object):
|
|||
distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
|
||||
distmat = distmat.numpy()
|
||||
|
||||
if rerank:
|
||||
print('Applying person re-ranking ...')
|
||||
distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
|
||||
distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
|
||||
distmat = re_ranking(distmat, distmat_qq, distmat_gg)
|
||||
|
||||
print('Computing CMC and mAP ...')
|
||||
cmc, mAP = metrics.evaluate_rank(
|
||||
distmat,
|
||||
|
|
|
@ -4,4 +4,5 @@ from .avgmeter import *
|
|||
from .loggers import *
|
||||
from .tools import *
|
||||
from .reidtools import *
|
||||
from .torchtools import *
|
||||
from .torchtools import *
|
||||
from .rerank import re_ranking
|
|
@ -0,0 +1,100 @@
|
|||
#!/usr/bin/env python2/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Source: https://github.com/zhunzhong07/person-re-ranking
|
||||
|
||||
Created on Mon Jun 26 14:46:56 2017
|
||||
@author: luohao
|
||||
Modified by Houjing Huang, 2017-12-22.
|
||||
- This version accepts distance matrix instead of raw features.
|
||||
- The difference of `/` division between python 2 and 3 is handled.
|
||||
- numpy.float16 is replaced by numpy.float32 for numerical precision.
|
||||
|
||||
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
|
||||
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
|
||||
Matlab version: https://github.com/zhunzhong07/person-re-ranking
|
||||
|
||||
API
|
||||
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
|
||||
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
|
||||
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
|
||||
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
|
||||
Returns:
|
||||
final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
__all__ = ['re_ranking']
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
|
||||
|
||||
# The following naming, e.g. gallery_num, is different from outer scope.
|
||||
# Don't care about it.
|
||||
|
||||
original_dist = np.concatenate(
|
||||
[np.concatenate([q_q_dist, q_g_dist], axis=1),
|
||||
np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
|
||||
axis=0)
|
||||
original_dist = np.power(original_dist, 2).astype(np.float32)
|
||||
original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
|
||||
V = np.zeros_like(original_dist).astype(np.float32)
|
||||
initial_rank = np.argsort(original_dist).astype(np.int32)
|
||||
|
||||
query_num = q_g_dist.shape[0]
|
||||
gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
|
||||
all_num = gallery_num
|
||||
|
||||
for i in range(all_num):
|
||||
# k-reciprocal neighbors
|
||||
forward_k_neigh_index = initial_rank[i,:k1+1]
|
||||
backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
|
||||
fi = np.where(backward_k_neigh_index==i)[0]
|
||||
k_reciprocal_index = forward_k_neigh_index[fi]
|
||||
k_reciprocal_expansion_index = k_reciprocal_index
|
||||
for j in range(len(k_reciprocal_index)):
|
||||
candidate = k_reciprocal_index[j]
|
||||
candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
|
||||
candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
|
||||
fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
|
||||
candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
|
||||
if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
|
||||
k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
|
||||
|
||||
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
|
||||
weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
|
||||
V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
|
||||
original_dist = original_dist[:query_num,]
|
||||
if k2 != 1:
|
||||
V_qe = np.zeros_like(V,dtype=np.float32)
|
||||
for i in range(all_num):
|
||||
V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
|
||||
V = V_qe
|
||||
del V_qe
|
||||
del initial_rank
|
||||
invIndex = []
|
||||
for i in range(gallery_num):
|
||||
invIndex.append(np.where(V[:,i] != 0)[0])
|
||||
|
||||
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
|
||||
|
||||
|
||||
for i in range(query_num):
|
||||
temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
|
||||
indNonZero = np.where(V[i,:] != 0)[0]
|
||||
indImages = []
|
||||
indImages = [invIndex[ind] for ind in indNonZero]
|
||||
for j in range(len(indNonZero)):
|
||||
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
|
||||
jaccard_dist[i] = 1-temp_min/(2.-temp_min)
|
||||
|
||||
final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
|
||||
del original_dist
|
||||
del V
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num,query_num:]
|
||||
return final_dist
|
Loading…
Reference in New Issue