add --normalize-feature to args parser

pull/173/head
KaiyangZhou 2019-05-09 23:11:10 +01:00
parent 5839a026c3
commit 3e10dc60dc
2 changed files with 25 additions and 7 deletions

View File

@ -163,6 +163,8 @@ def init_parser():
help='start to evaluate after a specific epoch')
parser.add_argument('--dist-metric', type=str, default='euclidean',
help='distance metric')
parser.add_argument('--normalize-feature', action='store_true',
help='normalize feature vectors before calculating distance')
parser.add_argument('--ranks', type=str, default=[1, 5, 10, 20], nargs='+',
help='cmc ranks')
parser.add_argument('--rerank', action='store_true',
@ -278,6 +280,7 @@ def engine_run_kwargs(parsed_args):
'test_only': parsed_args.evaluate,
'print_freq': parsed_args.print_freq,
'dist_metric': parsed_args.dist_metric,
'normalize_feature': parsed_args.normalize_feature,
'visrank': parsed_args.visrank,
'visrank_topk': parsed_args.visrank_topk,
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,

View File

@ -10,6 +10,7 @@ import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchreid
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking
@ -42,7 +43,7 @@ class Engine(object):
def run(self, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
dist_metric='euclidean', visrank=False, visrank_topk=20,
dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=20,
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
r"""A unified pipeline for training and evaluating a model.
@ -62,6 +63,8 @@ class Engine(object):
print_freq (int, optional): print_frequency. Default is 10.
dist_metric (str, optional): distance metric used to compute distance matrix
between query and gallery. Default is "euclidean".
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
computing feature distance. Default is False.
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
will be performed every test time, so it is recommended to enable ``visrank`` when
``test_only`` is True. The ranked images will be saved to
@ -70,7 +73,7 @@ class Engine(object):
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional): use person re-ranking (by Zhong et al. CVPR'17).
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
Default is False. This is only enabled when test_only=True.
"""
trainloader, testloader = self.datamanager.return_dataloaders()
@ -80,6 +83,7 @@ class Engine(object):
0,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
@ -107,6 +111,7 @@ class Engine(object):
epoch,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
@ -121,6 +126,7 @@ class Engine(object):
epoch,
testloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
@ -149,8 +155,9 @@ class Engine(object):
"""
raise NotImplementedError
def test(self, epoch, testloader, dist_metric='euclidean', visrank=False, visrank_topk=20,
save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
def test(self, epoch, testloader, dist_metric='euclidean', normalize_feature=False,
visrank=False, visrank_topk=20, save_dir='', use_metric_cuhk03=False,
ranks=[1, 5, 10, 20], rerank=False):
r"""Tests model on target datasets.
.. note::
@ -170,6 +177,8 @@ class Engine(object):
{dataset_name: 'query': queryloader, 'gallery': galleryloader}.
dist_metric (str, optional): distance metric used to compute distance matrix
between query and gallery. Default is "euclidean".
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
computing feature distance. Default is False.
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
will be performed every test time, so it is recommended to enable ``visrank`` when
``test_only`` is True. The ranked images will be saved to
@ -179,7 +188,7 @@ class Engine(object):
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional): use person re-ranking (by Zhong et al. CVPR'17).
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
Default is False.
"""
targets = list(testloader.keys())
@ -195,6 +204,7 @@ class Engine(object):
queryloader=queryloader,
galleryloader=galleryloader,
dist_metric=dist_metric,
normalize_feature=normalize_feature,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
@ -207,8 +217,9 @@ class Engine(object):
@torch.no_grad()
def _evaluate(self, epoch, dataset_name='', queryloader=None, galleryloader=None,
dist_metric='euclidean', visrank=False, visrank_topk=20, save_dir='',
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
dist_metric='euclidean', normalize_feature=False, visrank=False,
visrank_topk=20, save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20],
rerank=False):
batch_time = AverageMeter()
self.model.eval()
@ -252,6 +263,10 @@ class Engine(object):
print('Speed: {:.4f} sec/batch'.format(batch_time.avg))
if normalize_feature:
qf = F.normalize(qf, p=2, dim=1)
gf = F.normalize(gf, p=2, dim=1)
distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
distmat = distmat.numpy()