2020-05-10 23:17:10 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: xingyu liao
|
2020-07-29 17:43:39 +08:00
|
|
|
@contact: sherlockliao01@gmail.com
|
2020-05-10 23:17:10 +08:00
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2020-05-19 20:45:26 +08:00
|
|
|
import tqdm
|
2020-05-10 23:17:10 +08:00
|
|
|
from torch.backends import cudnn
|
|
|
|
|
|
|
|
sys.path.append('.')
|
|
|
|
|
|
|
|
from fastreid.evaluation import evaluate_rank
|
|
|
|
from fastreid.config import get_cfg
|
|
|
|
from fastreid.utils.logger import setup_logger
|
|
|
|
from fastreid.data import build_reid_test_loader
|
|
|
|
from predictor import FeatureExtractionDemo
|
|
|
|
from fastreid.utils.visualizer import Visualizer
|
|
|
|
|
2020-12-04 11:14:18 +08:00
|
|
|
# import some modules added in project
|
|
|
|
# for example, to add PartialReID, uncomment the lines below
|
2021-04-21 16:24:34 +08:00
|
|
|
# sys.path.append("projects/PartialReID")
|
|
|
|
# from partialreid import *
|
2020-12-04 11:14:18 +08:00
|
|
|
|
2020-05-10 23:17:10 +08:00
|
|
|
# Script-wide logging: initialise the package-level "fastreid" logger via the
# project helper (presumably attaches handlers/formatters — see
# fastreid.utils.logger), then fetch this script's child logger so its
# records propagate to it.
setup_logger(name="fastreid")
logger = logging.getLogger('fastreid.visualize_result')

# Enable cuDNN benchmark mode: the autotuner picks the fastest convolution
# algorithms for the input shapes it sees.
cudnn.benchmark = True
|
2020-05-10 23:17:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def setup_cfg(args):
    """Build the frozen experiment config for this script.

    The default config from ``get_cfg()`` is overlaid first with the YAML file
    named by ``args.config_file``, then with the ``KEY VALUE`` overrides in
    ``args.opts``, and finally frozen so it cannot be mutated afterwards.

    Args:
        args: parsed command-line namespace (see ``get_parser``).

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    # Project-specific config extensions hook in here, e.g. for PartialReID:
    # add_partialreid_config(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)  # command-line overrides win over the file
    config.freeze()
    return config
|
|
|
|
|
|
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the rank-list visualization script.

    Numeric options (``--num-vis``, ``--max-rank``) are converted to ``int`` so
    user-supplied values have the same type as their defaults. ``--opts``
    consumes every remaining argument as ``KEY VALUE`` config overrides, so it
    must come last on the command line.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Feature extraction with reid models")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        '--parallel',
        action='store_true',
        help='if use multiprocess for feature extraction.'
    )
    parser.add_argument(
        "--dataset-name",
        help="a test dataset name for visualizing ranking list."
    )
    parser.add_argument(
        "--output",
        default="./vis_rank_list",
        help="a file or directory to save ranking list result.",
    )
    parser.add_argument(
        "--vis-label",
        action='store_true',
        help="if visualize label of query instance"
    )
    parser.add_argument(
        "--num-vis",
        type=int,  # without this, a CLI-supplied value would be a str
        default=100,
        help="number of query images to be visualized",
    )
    parser.add_argument(
        "--rank-sort",
        default="ascending",
        help="rank order of visualization images by AP metric",
    )
    parser.add_argument(
        "--label-sort",
        default="ascending",
        help="label order of visualization images by cosine similarity metric",
    )
    parser.add_argument(
        "--max-rank",
        type=int,  # without this, a CLI-supplied value would be a str
        default=10,
        help="maximum number of rank list to be visualized",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Visualize re-id ranking results for one test dataset.

    Driven entirely by command-line arguments (see ``get_parser``):

    1. Build the config and the test loader for ``--dataset-name``. The first
       ``num_query`` extracted samples are treated as queries, the rest as
       gallery.
    2. Extract features, person ids and camera ids for every sample.
    3. Compute the query-by-gallery distance matrix ``1 - q @ g.T`` and run
       ``evaluate_rank`` to obtain per-query average precision.
    4. Save the ROC curve and the ranked retrieval lists under ``--output``.

    Raises:
        RuntimeError: if the test loader yields no batches.
    """
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    test_loader, num_query = build_reid_test_loader(cfg, dataset_name=args.dataset_name)
    demo = FeatureExtractionDemo(cfg, parallel=args.parallel)

    logger.info("Start extracting image features")
    feats, pids, camids = [], [], []
    for feat, pid, camid in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
        feats.append(feat)
        pids.extend(pid)
        camids.extend(camid)

    if not feats:
        # torch.cat([]) would fail with an opaque error; report the real cause.
        raise RuntimeError(
            f"No features were extracted for dataset {args.dataset_name!r}; "
            "check --dataset-name and the test loader configuration."
        )

    feats = torch.cat(feats, dim=0)
    # Query rows come first, gallery rows after them.
    q_feat, g_feat = feats[:num_query], feats[num_query:]
    q_pids, g_pids = np.asarray(pids[:num_query]), np.asarray(pids[num_query:])
    q_camids, g_camids = np.asarray(camids[:num_query]), np.asarray(camids[num_query:])

    # compute cosine distance: 1 - <q, g>. This equals cosine distance only for
    # L2-normalized features — presumably normalized inside FeatureExtractionDemo; verify.
    distmat = 1 - torch.mm(q_feat, g_feat.t())
    # .numpy() requires a CPU tensor that does not track gradients.
    distmat = distmat.detach().cpu().numpy()

    logger.info("Computing APs for all query images ...")
    _cmc, all_ap, _all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info("Finish computing APs for all query images!")

    visualizer = Visualizer(test_loader.dataset)
    visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)

    logger.info("Start saving ROC curve ...")
    fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
    visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
    logger.info("Finish saving ROC curve!")

    logger.info("Saving rank list result ...")
    visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
                             args.rank_sort, args.label_sort, args.max_rank)
    logger.info("Finish saving rank list results!")


if __name__ == '__main__':
    main()
|