add vis-ranked-res

pull/62/head
KaiyangZhou 2018-08-01 10:52:06 +01:00
parent 410ed98c97
commit 83a1bb7bec
1 changed files with 13 additions and 2 deletions

View File

@ -24,6 +24,7 @@ from utils.iotools import save_checkpoint
from utils.avgmeter import AverageMeter
from utils.logger import Logger
from utils.torchtools import set_bn_to_eval, count_num_param
from utils.reidtools import visualize_ranked_results
from eval_metrics import evaluate
from optimizers import init_optim
@ -97,6 +98,8 @@ parser.add_argument('--use-cpu', action='store_true',
help="use cpu")
parser.add_argument('--gpu-devices', default='0', type=str,
help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--vis-ranked-res', action='store_true',
help="visualize ranked results, only available in evaluation mode (default: False)")
args = parser.parse_args()
@ -202,7 +205,13 @@ def main():
if args.evaluate:
print("Evaluate only")
test(model, queryloader, galleryloader, use_gpu)
_, distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)
if args.vis_ranked_res:
visualize_ranked_results(
distmat, dataset,
save_dir=osp.join(args.save_dir, 'ranked_results'),
topk=20,
)
return
start_time = time.time()
@ -298,7 +307,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
end = time.time()
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False):
batch_time = AverageMeter()
model.eval()
@ -359,6 +368,8 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
print("------------------")
if return_distmat:
return cmc[0], distmat
return cmc[0]