v1.3.1: record results (rank1 & mAP) of each target dataset
parent
8bbe93ffb5
commit
de42729f41
|
@ -2,7 +2,7 @@ from __future__ import print_function, absolute_import
|
||||||
|
|
||||||
from torchreid import data, optim, utils, engine, losses, models, metrics
|
from torchreid import data, optim, utils, engine, losses, models, metrics
|
||||||
|
|
||||||
__version__ = '1.3.0'
|
__version__ = '1.3.1'
|
||||||
__author__ = 'Kaiyang Zhou'
|
__author__ = 'Kaiyang Zhou'
|
||||||
__homepage__ = 'https://kaiyangzhou.github.io/'
|
__homepage__ = 'https://kaiyangzhou.github.io/'
|
||||||
__description__ = 'Deep learning person re-identification in PyTorch'
|
__description__ = 'Deep learning person re-identification in PyTorch'
|
||||||
|
|
|
@ -31,6 +31,7 @@ class Engine(object):
|
||||||
self.test_loader = self.datamanager.test_loader
|
self.test_loader = self.datamanager.test_loader
|
||||||
self.use_gpu = (torch.cuda.is_available() and use_gpu)
|
self.use_gpu = (torch.cuda.is_available() and use_gpu)
|
||||||
self.writer = None
|
self.writer = None
|
||||||
|
self.epoch = 0
|
||||||
|
|
||||||
self.model = None
|
self.model = None
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
@ -166,7 +167,6 @@ class Engine(object):
|
||||||
|
|
||||||
if test_only:
|
if test_only:
|
||||||
self.test(
|
self.test(
|
||||||
0,
|
|
||||||
dist_metric=dist_metric,
|
dist_metric=dist_metric,
|
||||||
normalize_feature=normalize_feature,
|
normalize_feature=normalize_feature,
|
||||||
visrank=visrank,
|
visrank=visrank,
|
||||||
|
@ -198,7 +198,6 @@ class Engine(object):
|
||||||
and (self.epoch+1) % eval_freq == 0 \
|
and (self.epoch+1) % eval_freq == 0 \
|
||||||
and (self.epoch + 1) != self.max_epoch:
|
and (self.epoch + 1) != self.max_epoch:
|
||||||
rank1 = self.test(
|
rank1 = self.test(
|
||||||
self.epoch,
|
|
||||||
dist_metric=dist_metric,
|
dist_metric=dist_metric,
|
||||||
normalize_feature=normalize_feature,
|
normalize_feature=normalize_feature,
|
||||||
visrank=visrank,
|
visrank=visrank,
|
||||||
|
@ -208,12 +207,10 @@ class Engine(object):
|
||||||
ranks=ranks
|
ranks=ranks
|
||||||
)
|
)
|
||||||
self.save_model(self.epoch, rank1, save_dir)
|
self.save_model(self.epoch, rank1, save_dir)
|
||||||
self.writer.add_scalar('Test/rank1', rank1, self.epoch)
|
|
||||||
|
|
||||||
if self.max_epoch > 0:
|
if self.max_epoch > 0:
|
||||||
print('=> Final test')
|
print('=> Final test')
|
||||||
rank1 = self.test(
|
rank1 = self.test(
|
||||||
self.epoch,
|
|
||||||
dist_metric=dist_metric,
|
dist_metric=dist_metric,
|
||||||
normalize_feature=normalize_feature,
|
normalize_feature=normalize_feature,
|
||||||
visrank=visrank,
|
visrank=visrank,
|
||||||
|
@ -223,7 +220,6 @@ class Engine(object):
|
||||||
ranks=ranks
|
ranks=ranks
|
||||||
)
|
)
|
||||||
self.save_model(self.epoch, rank1, save_dir)
|
self.save_model(self.epoch, rank1, save_dir)
|
||||||
self.writer.add_scalar('Test/rank1', rank1, self.epoch)
|
|
||||||
|
|
||||||
elapsed = round(time.time() - time_start)
|
elapsed = round(time.time() - time_start)
|
||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
|
@ -295,7 +291,6 @@ class Engine(object):
|
||||||
|
|
||||||
def test(
|
def test(
|
||||||
self,
|
self,
|
||||||
epoch,
|
|
||||||
dist_metric='euclidean',
|
dist_metric='euclidean',
|
||||||
normalize_feature=False,
|
normalize_feature=False,
|
||||||
visrank=False,
|
visrank=False,
|
||||||
|
@ -326,8 +321,7 @@ class Engine(object):
|
||||||
print('##### Evaluating {} ({}) #####'.format(name, domain))
|
print('##### Evaluating {} ({}) #####'.format(name, domain))
|
||||||
query_loader = self.test_loader[name]['query']
|
query_loader = self.test_loader[name]['query']
|
||||||
gallery_loader = self.test_loader[name]['gallery']
|
gallery_loader = self.test_loader[name]['gallery']
|
||||||
rank1 = self._evaluate(
|
rank1, mAP = self._evaluate(
|
||||||
epoch,
|
|
||||||
dataset_name=name,
|
dataset_name=name,
|
||||||
query_loader=query_loader,
|
query_loader=query_loader,
|
||||||
gallery_loader=gallery_loader,
|
gallery_loader=gallery_loader,
|
||||||
|
@ -341,12 +335,15 @@ class Engine(object):
|
||||||
rerank=rerank
|
rerank=rerank
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.writer is not None:
|
||||||
|
self.writer.add_scalar(f'Test/{name}/rank1', rank1, self.epoch)
|
||||||
|
self.writer.add_scalar(f'Test/{name}/mAP', mAP, self.epoch)
|
||||||
|
|
||||||
return rank1
|
return rank1
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _evaluate(
|
def _evaluate(
|
||||||
self,
|
self,
|
||||||
epoch,
|
|
||||||
dataset_name='',
|
dataset_name='',
|
||||||
query_loader=None,
|
query_loader=None,
|
||||||
gallery_loader=None,
|
gallery_loader=None,
|
||||||
|
@ -433,7 +430,7 @@ class Engine(object):
|
||||||
topk=visrank_topk
|
topk=visrank_topk
|
||||||
)
|
)
|
||||||
|
|
||||||
return cmc[0]
|
return cmc[0], mAP
|
||||||
|
|
||||||
def compute_loss(self, criterion, outputs, targets):
|
def compute_loss(self, criterion, outputs, targets):
|
||||||
if isinstance(outputs, (tuple, list)):
|
if isinstance(outputs, (tuple, list)):
|
||||||
|
|
Loading…
Reference in New Issue