2020-05-10 23:17:10 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: xingyu liao
|
|
|
|
@contact: liaoxingyu5@jd.com
|
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2020-05-19 20:45:26 +08:00
|
|
|
import tqdm
|
2020-05-10 23:17:10 +08:00
|
|
|
from torch.backends import cudnn
|
|
|
|
|
|
|
|
sys.path.append('.')
|
|
|
|
|
|
|
|
from fastreid.evaluation import evaluate_rank
|
|
|
|
from fastreid.config import get_cfg
|
|
|
|
from fastreid.utils.logger import setup_logger
|
|
|
|
from fastreid.data import build_reid_test_loader
|
|
|
|
from predictor import FeatureExtractionDemo
|
|
|
|
from fastreid.utils.visualizer import Visualizer
|
|
|
|
|
|
|
|
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest one.
# This only pays off when input sizes stay constant across batches — presumably
# true for the test transforms used here (TODO confirm against the config).
cudnn.benchmark = True
|
2020-05-27 22:53:27 +08:00
|
|
|
# Module-level logger for this script; the dotted name places it under the
# 'fastreid' logger in Python's logging hierarchy.
logger = logging.getLogger('fastreid.visualize_result')
|
2020-05-10 23:17:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def setup_cfg(args):
    """Create the run configuration from parsed command-line arguments.

    Starts from the default config returned by ``get_cfg()``, merges the file
    named by ``args.config_file``, then merges the ``KEY VALUE`` overrides in
    ``args.opts`` (applied after the file), and finally calls ``freeze()``.

    Args:
        args: namespace produced by ``get_parser().parse_args()``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)  # base settings from the config file
    config.merge_from_list(args.opts)         # command-line overrides, merged second
    config.freeze()
    return config
|
|
|
|
|
|
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the ranking-list visualization script.

    Returns:
        argparse.ArgumentParser: parser for the config file, device, dataset,
        output location and visualization options. Any tokens after ``--opts``
        are collected verbatim (``argparse.REMAINDER``) as config overrides.
    """
    parser = argparse.ArgumentParser(description="Feature extraction with reid models")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        '--device',
        # No space after the colon: if this string reaches torch.device,
        # newer PyTorch versions reject 'cuda: 1' as an invalid device string.
        default='cuda:1',
        help='CUDA device to use'
    )
    parser.add_argument(
        '--parallel',
        action='store_true',
        help='if use multiprocess for feature extraction.'
    )
    parser.add_argument(
        "--dataset-name",
        help="a test dataset name for visualizing ranking list."
    )
    parser.add_argument(
        "--output",
        default="./vis_rank_list",
        help="a file or directory to save ranking list result.",
    )
    parser.add_argument(
        "--vis-label",
        action='store_true',
        help="if visualize label of query instance"
    )
    # type=int is required: without it, a value given on the command line is
    # kept as a str while the default is an int.
    parser.add_argument(
        "--num-vis",
        type=int,
        default=100,
        help="number of query images to be visualized",
    )
    parser.add_argument(
        "--rank-sort",
        default="ascending",
        help="rank order of visualization images by AP metric",
    )
    parser.add_argument(
        "--label-sort",
        default="ascending",
        help="label order of visualization images by cosine similarity metric",
    )
    parser.add_argument(
        "--max-rank",
        type=int,
        default=10,
        help="maximum number of rank list to be visualized",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Extract features, evaluate the ranking, and save ROC / rank-list visualizations.

    Parses the command line (see ``get_parser``) and builds the test loader
    for ``--dataset-name``. The first ``num_query`` extracted samples are
    treated as queries and the remaining samples as gallery.

    It then computes a query-by-gallery cosine-distance matrix, scores it
    with ``evaluate_rank``, and writes the ROC curve and ranking-list images
    under ``--output``.
    """
    args = get_parser().parse_args()
    logger = setup_logger()
    cfg = setup_cfg(args)

    test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
    demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)

    logger.info("Start extracting image features")
    feats, pids, camids = [], [], []
    for feat, pid, camid in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
        feats.append(feat)
        pids.extend(pid)
        camids.extend(camid)

    # torch.cat([]) would fail with an unhelpful message; report the real cause.
    if not feats:
        raise RuntimeError(
            "No features were extracted for dataset {!r}".format(args.dataset_name))

    # torch.mm below needs 2-D (num_samples, feat_dim) matrices. L2-normalizing
    # rows makes the inner product an exact cosine similarity. This does not
    # rely on the extractor having normalized already; if it has, the call
    # is a no-op.
    feats = torch.nn.functional.normalize(torch.cat(feats, dim=0), dim=1)

    q_feat, g_feat = feats[:num_query], feats[num_query:]
    q_pids, g_pids = np.asarray(pids[:num_query]), np.asarray(pids[num_query:])
    q_camids, g_camids = np.asarray(camids[:num_query]), np.asarray(camids[num_query:])

    # Cosine distance. .cpu() is needed because .numpy() only accepts CPU
    # tensors; it is a no-op when the features are already on the CPU.
    distmat = (1 - torch.mm(q_feat, g_feat.t())).cpu().numpy()

    logger.info("Computing APs for all query images ...")
    _, all_ap, _ = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)

    visualizer = Visualizer(test_loader.dataset)
    visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)

    logger.info("Saving ROC curve ...")
    fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
    visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)

    logger.info("Saving rank list result ...")
    visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
                             args.rank_sort, args.label_sort, args.max_rank)


if __name__ == '__main__':
    main()
|