v1.3.1: record results (rank1 & mAP) of each target dataset
parent
8bbe93ffb5
commit
de42729f41
|
@ -2,7 +2,7 @@ from __future__ import print_function, absolute_import
|
|||
|
||||
from torchreid import data, optim, utils, engine, losses, models, metrics
|
||||
|
||||
__version__ = '1.3.0'
|
||||
__version__ = '1.3.1'
|
||||
__author__ = 'Kaiyang Zhou'
|
||||
__homepage__ = 'https://kaiyangzhou.github.io/'
|
||||
__description__ = 'Deep learning person re-identification in PyTorch'
|
||||
|
|
|
@ -31,6 +31,7 @@ class Engine(object):
|
|||
self.test_loader = self.datamanager.test_loader
|
||||
self.use_gpu = (torch.cuda.is_available() and use_gpu)
|
||||
self.writer = None
|
||||
self.epoch = 0
|
||||
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
|
@ -166,7 +167,6 @@ class Engine(object):
|
|||
|
||||
if test_only:
|
||||
self.test(
|
||||
0,
|
||||
dist_metric=dist_metric,
|
||||
normalize_feature=normalize_feature,
|
||||
visrank=visrank,
|
||||
|
@ -198,7 +198,6 @@ class Engine(object):
|
|||
and (self.epoch+1) % eval_freq == 0 \
|
||||
and (self.epoch + 1) != self.max_epoch:
|
||||
rank1 = self.test(
|
||||
self.epoch,
|
||||
dist_metric=dist_metric,
|
||||
normalize_feature=normalize_feature,
|
||||
visrank=visrank,
|
||||
|
@ -208,12 +207,10 @@ class Engine(object):
|
|||
ranks=ranks
|
||||
)
|
||||
self.save_model(self.epoch, rank1, save_dir)
|
||||
self.writer.add_scalar('Test/rank1', rank1, self.epoch)
|
||||
|
||||
if self.max_epoch > 0:
|
||||
print('=> Final test')
|
||||
rank1 = self.test(
|
||||
self.epoch,
|
||||
dist_metric=dist_metric,
|
||||
normalize_feature=normalize_feature,
|
||||
visrank=visrank,
|
||||
|
@ -223,7 +220,6 @@ class Engine(object):
|
|||
ranks=ranks
|
||||
)
|
||||
self.save_model(self.epoch, rank1, save_dir)
|
||||
self.writer.add_scalar('Test/rank1', rank1, self.epoch)
|
||||
|
||||
elapsed = round(time.time() - time_start)
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
@ -295,7 +291,6 @@ class Engine(object):
|
|||
|
||||
def test(
|
||||
self,
|
||||
epoch,
|
||||
dist_metric='euclidean',
|
||||
normalize_feature=False,
|
||||
visrank=False,
|
||||
|
@ -326,8 +321,7 @@ class Engine(object):
|
|||
print('##### Evaluating {} ({}) #####'.format(name, domain))
|
||||
query_loader = self.test_loader[name]['query']
|
||||
gallery_loader = self.test_loader[name]['gallery']
|
||||
rank1 = self._evaluate(
|
||||
epoch,
|
||||
rank1, mAP = self._evaluate(
|
||||
dataset_name=name,
|
||||
query_loader=query_loader,
|
||||
gallery_loader=gallery_loader,
|
||||
|
@ -341,12 +335,15 @@ class Engine(object):
|
|||
rerank=rerank
|
||||
)
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_scalar(f'Test/{name}/rank1', rank1, self.epoch)
|
||||
self.writer.add_scalar(f'Test/{name}/mAP', mAP, self.epoch)
|
||||
|
||||
return rank1
|
||||
|
||||
@torch.no_grad()
|
||||
def _evaluate(
|
||||
self,
|
||||
epoch,
|
||||
dataset_name='',
|
||||
query_loader=None,
|
||||
gallery_loader=None,
|
||||
|
@ -433,7 +430,7 @@ class Engine(object):
|
|||
topk=visrank_topk
|
||||
)
|
||||
|
||||
return cmc[0]
|
||||
return cmc[0], mAP
|
||||
|
||||
def compute_loss(self, criterion, outputs, targets):
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
|
|
Loading…
Reference in New Issue