mirror of https://github.com/JDAI-CV/fast-reid.git
fix demo eval bug (#345)
Summary: Fix bug in demo.py when computing cmc and mAP in `evaluate_rank`. This is because after changing eval function, the passing params in demo are not changed. close #345pull/365/head
parent
2724515fd9
commit
1b9799f601
|
@ -17,13 +17,15 @@ from torch.backends import cudnn
|
|||
sys.path.append('..')
|
||||
|
||||
from fastreid.config import get_cfg
|
||||
from fastreid.utils.logger import setup_logger
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from predictor import FeatureExtractionDemo
|
||||
|
||||
from predictor import FeatureExtractionDemo
|
||||
# import some modules added in project like this below
|
||||
# from projects.PartialReID.partialreid import *
|
||||
|
||||
cudnn.benchmark = True
|
||||
setup_logger(name="fastreid")
|
||||
|
||||
|
||||
def setup_cfg(args):
|
||||
|
|
|
@ -22,7 +22,13 @@ from fastreid.data import build_reid_test_loader
|
|||
from predictor import FeatureExtractionDemo
|
||||
from fastreid.utils.visualizer import Visualizer
|
||||
|
||||
# import some modules added in project
|
||||
# for example, add partial reid like this below
|
||||
# from projects.PartialReID.partialreid import *
|
||||
|
||||
cudnn.benchmark = True
|
||||
setup_logger(name="fastreid")
|
||||
|
||||
logger = logging.getLogger('fastreid.visualize_result')
|
||||
|
||||
|
||||
|
@ -93,7 +99,6 @@ def get_parser():
|
|||
|
||||
if __name__ == '__main__':
|
||||
args = get_parser().parse_args()
|
||||
logger = setup_logger()
|
||||
cfg = setup_cfg(args)
|
||||
test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
|
||||
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
|
||||
|
@ -120,15 +125,18 @@ if __name__ == '__main__':
|
|||
distmat = distmat.numpy()
|
||||
|
||||
logger.info("Computing APs for all query images ...")
|
||||
cmc, all_ap, all_inp = evaluate_rank(distmat, q_feat, g_feat, q_pids, g_pids, q_camids, g_camids)
|
||||
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
logger.info("Finish computing APs for all query images!")
|
||||
|
||||
visualizer = Visualizer(test_loader.dataset)
|
||||
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
|
||||
logger.info("Saving ROC curve ...")
|
||||
logger.info("Start saving ROC curve ...")
|
||||
fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
|
||||
visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
|
||||
logger.info("Finish saving ROC curve!")
|
||||
|
||||
logger.info("Saving rank list result ...")
|
||||
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
|
||||
args.rank_sort, args.label_sort, args.max_rank)
|
||||
logger.info("Finish saving rank list results!")
|
||||
|
|
Loading…
Reference in New Issue