# fast-reid/demo/visualize_result.py
#
# 140 lines
# 3.9 KiB
# Python
# Raw Normal View History

# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
import argparse
import logging
import sys
import numpy as np
import torch
import tqdm
from torch.backends import cudnn
sys.path.append('.')
from fastreid.evaluation import evaluate_rank
from fastreid.config import get_cfg
from fastreid.utils.logger import setup_logger
from fastreid.data import build_reid_test_loader
from predictor import FeatureExtractionDemo
from fastreid.utils.visualizer import Visualizer
# Turn on cuDNN benchmark mode for the whole process.
cudnn.benchmark = True

# Module-level logger; the script entry point rebinds this name to the
# logger returned by ``setup_logger()``.
logger = logging.getLogger('fastreid')
def setup_cfg(args):
    """Create the frozen configuration used by the demo.

    Args:
        args: parsed command-line namespace; ``args.config_file`` and
            ``args.opts`` are read.

    Returns:
        The config object obtained from ``get_cfg()`` after merging the
        config file, then the command-line option list, and freezing it.
    """
    config = get_cfg()
    # Merge order matters: the file is applied first, the command-line
    # 'KEY VALUE' list second.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def get_parser():
    """Build the command-line parser for the rank-list visualization demo.

    ``--num-vis`` and ``--max-rank`` are declared with ``type=int`` so that a
    value supplied on the command line has the same type as the integer
    default (without it argparse would hand back a ``str``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Feature extraction with reid models")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    # NOTE(review): the default contains a space after the colon ('cuda: 1');
    # confirm the consumer of ``args.device`` accepts this spelling.
    parser.add_argument(
        '--device',
        default='cuda: 1',
        help='CUDA device to use'
    )
    parser.add_argument(
        '--parallel',
        action='store_true',
        help='if use multiprocess for feature extraction.'
    )
    parser.add_argument(
        "--dataset-name",
        help="a test dataset name for visualizing ranking list."
    )
    parser.add_argument(
        "--output",
        default="./vis_rank_list",
        help="a file or directory to save rankling list result.",
    )
    parser.add_argument(
        "--vis-label",
        action='store_true',
        help="if visualize label of query instance"
    )
    parser.add_argument(
        "--num-vis",
        type=int,
        default=100,
        help="number of query images to be visualized",
    )
    parser.add_argument(
        "--rank-sort",
        default="ascending",
        help="rank order of visualization images by AP metric",
    )
    parser.add_argument(
        "--label-sort",
        default="ascending",
        help="label order of visualization images by cosine similarity metric",
    )
    parser.add_argument(
        "--max-rank",
        type=int,
        default=10,
        help="maximum number of rank list to be visualized",
    )
    # Everything after ``--opts`` is collected verbatim as a flat list.
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser
if __name__ == '__main__':
    # Command line -> logger -> frozen config -> data loader -> extractor.
    args = get_parser().parse_args()
    logger = setup_logger()
    cfg = setup_cfg(args)
    test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
    demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)

    logger.info("Start extracting image features")
    feat_batches, all_pids, all_camids = [], [], []
    progress = tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader))
    for batch_feat, batch_pids, batch_camids in progress:
        feat_batches.append(batch_feat)
        all_pids.extend(batch_pids)
        all_camids.extend(batch_camids)
    all_feats = torch.cat(feat_batches, dim=0)

    # The first ``num_query`` samples are treated as queries, the remainder
    # as the gallery.
    query_part = slice(None, num_query)
    gallery_part = slice(num_query, None)
    q_feat, g_feat = all_feats[query_part], all_feats[gallery_part]
    q_pids = np.asarray(all_pids[query_part])
    g_pids = np.asarray(all_pids[gallery_part])
    q_camids = np.asarray(all_camids[query_part])
    g_camids = np.asarray(all_camids[gallery_part])

    # 1 - <q, g> per query/gallery pair; this is a cosine distance only if
    # the features are L2-normalised -- TODO confirm against the extractor.
    distmat = (1 - torch.mm(q_feat, g_feat.t())).numpy()

    logger.info("Computing APs for all query images ...")
    # Only the per-query AP values are used below.
    _cmc, all_ap, _all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)

    visualizer = Visualizer(test_loader.dataset)
    visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)

    logger.info("Saving ROC curve ...")
    fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
    visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)

    logger.info("Saving rank list result ...")
    query_indices = visualizer.vis_rank_list(
        args.output,
        args.vis_label,
        args.num_vis,
        args.rank_sort,
        args.label_sort,
        args.max_rank,
    )