fix demo eval bug (#345)

Summary: Fix bug in demo.py when computing cmc and mAP in `evaluate_rank`. This is because after changing eval function, the passing params in demo are not changed.

close #345
pull/365/head
liaoxingyu 2020-12-04 11:14:18 +08:00
parent 2724515fd9
commit 1b9799f601
2 changed files with 14 additions and 4 deletions

View File

@ -17,13 +17,15 @@ from torch.backends import cudnn
sys.path.append('..')
from fastreid.config import get_cfg
from fastreid.utils.logger import setup_logger
from fastreid.utils.file_io import PathManager
from predictor import FeatureExtractionDemo
from predictor import FeatureExtractionDemo
# import some modules added in project like this below
# from projects.PartialReID.partialreid import *
cudnn.benchmark = True
setup_logger(name="fastreid")
def setup_cfg(args):

View File

@ -22,7 +22,13 @@ from fastreid.data import build_reid_test_loader
from predictor import FeatureExtractionDemo
from fastreid.utils.visualizer import Visualizer
# import some modules added in project
# for example, add partial reid like this below
# from projects.PartialReID.partialreid import *
cudnn.benchmark = True
setup_logger(name="fastreid")
logger = logging.getLogger('fastreid.visualize_result')
@ -93,7 +99,6 @@ def get_parser():
if __name__ == '__main__':
args = get_parser().parse_args()
logger = setup_logger()
cfg = setup_cfg(args)
test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
@ -120,15 +125,18 @@ if __name__ == '__main__':
distmat = distmat.numpy()
logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(distmat, q_feat, g_feat, q_pids, g_pids, q_camids, g_camids)
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Finish computing APs for all query images!")
visualizer = Visualizer(test_loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ROC curve ...")
logger.info("Start saving ROC curve ...")
fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
logger.info("Finish saving ROC curve!")
logger.info("Saving rank list result ...")
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
args.rank_sort, args.label_sort, args.max_rank)
logger.info("Finish saving rank list results!")