diff --git a/demo/run_demo.sh b/demo/run_demo.sh index a4bc776..3e514bb 100644 --- a/demo/run_demo.sh +++ b/demo/run_demo.sh @@ -1,6 +1,4 @@ python demo/visualize_ranking.py --config-file 'logs/market1501/sbs_R50/config.yaml' \ ---parallel \ ---dataset-name 'Market1501' \ ---output 'logs/test_ranking' \ +--parallel --vis-label --dataset-name 'Market1501' --output 'logs/test_ranking' \ --opts MODEL.WEIGHTS 'logs/market1501/sbs_R50/model_final.pth' diff --git a/demo/visualize_ranking.py b/demo/visualize_ranking.py index 80b64dd..0c1e8d3 100644 --- a/demo/visualize_ranking.py +++ b/demo/visualize_ranking.py @@ -8,6 +8,7 @@ import argparse import logging import tqdm import sys +import os import numpy as np import torch @@ -62,6 +63,11 @@ def get_parser(): help="a file or directory to save rankling list result.", ) + parser.add_argument( + "--vis-label", + action='store_true', + help="if visualize label of query instance" + ) parser.add_argument( "--num-vis", default=100, @@ -72,6 +78,11 @@ def get_parser(): default="ascending", help="rank order of visualization images by AP metric", ) + parser.add_argument( + "--label-sort", + default="ascending", + help="label order of visualization images by cosine similarity metric", + ) parser.add_argument( "--max-rank", default=10, @@ -115,10 +126,10 @@ if __name__ == '__main__': distmat = distmat.numpy() logger.info("Computing APs for all query images ...") - cmc, all_ap, all_inp = evaluate_rank(1-distmat, q_pids, g_pids, q_camids, g_camids) + cmc, all_ap, all_inp = evaluate_rank(1 - distmat, q_pids, g_pids, q_camids, g_camids) visualizer = Visualizer(test_loader.loader.dataset) visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids) - logger.info("Saving ranking list result ...") - visualizer.vis_ranking_list(args.output, args.num_vis, rank_sort=args.rank_sort, max_rank=args.max_rank) - + logger.info("Saving rank list result ...") + query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis, + args.rank_sort, args.label_sort, args.max_rank) diff --git a/fastreid/utils/visualizer.py b/fastreid/utils/visualizer.py index 1e63db4..f0bd66b 100644 --- a/fastreid/utils/visualizer.py +++ b/fastreid/utils/visualizer.py @@ -7,7 +7,6 @@ import os import cv2 -import matplotlib.figure as mplfigure import matplotlib.pyplot as plt import numpy as np import tqdm @@ -46,9 +45,12 @@ class Visualizer: sort_idx = order[keep] return cmc, sort_idx - def save_rank_result(self, query_indices, output, max_rank=5, actmap=False): - fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6)) - # fig.suptitle('query/AP/camid sim/true(false)/camid') + def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending', + actmap=False): + if vis_label: + fig, axes = plt.subplots(2, max_rank + 1, figsize=(3 * max_rank, 12)) + else: + fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6)) for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)): all_imgs = [] cmc, sort_idx = self.get_matched_result(q_idx) @@ -58,11 +60,16 @@ class Visualizer: query_name = query_info['img_path'].split('/')[-1] all_imgs.append(query_img) query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3) - axes.flat[0].imshow(query_img) - axes.flat[0].set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id)) - axes.flat[0].axis("off") - # print('query' + query_info['img_path'].split('/')[-1]) + plt.clf() + ax = fig.add_subplot(1, max_rank + 1, 1) + ax.imshow(query_img) + ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id)) + ax.axis("off") for i in range(max_rank): + if vis_label: + ax = fig.add_subplot(2, max_rank + 1, i + 2) + else: + ax = fig.add_subplot(1, max_rank + 1, i + 2) g_idx = self.num_query + sort_idx[i] gallery_info = self.dataset[g_idx] gallery_img = gallery_info['images'] @@ -71,18 +78,17 @@ class Visualizer: gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3) if cmc[i] == 1: label = 'true' - axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, - height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0), - fill=False, linewidth=5)) + ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, + height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0), + fill=False, linewidth=5)) else: label = 'false' - axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, - height=gallery_img.shape[0] - 1, - edgecolor=(0, 0, 1), fill=False, linewidth=5)) - axes.flat[i + 1].imshow(gallery_img) - # print('/'.join(gallery_info['img_path'].split('/')[-2:])) - axes.flat[i + 1].set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}') - axes.flat[i + 1].axis("off") + ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, + height=gallery_img.shape[0] - 1, + edgecolor=(0, 0, 1), fill=False, linewidth=5)) + ax.imshow(gallery_img) + ax.set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}') + ax.axis("off") # if actmap: # act_outputs = [] # @@ -101,19 +107,42 @@ class Visualizer: # acts = self.get_actmap(act_outputs[0], sz) # for i in range(top + 1): # axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet') + if vis_label: + label_indice = np.where(cmc == 1)[0] + if label_sort == "ascending": label_indice = label_indice[::-1] + label_indice = label_indice[:max_rank] + for i in range(max_rank): + if i >= len(label_indice): break + j = label_indice[i] + g_idx = self.num_query + sort_idx[j] + gallery_info = self.dataset[g_idx] + gallery_img = gallery_info['images'] + cam_id = gallery_info['camid'] + gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3) + ax = fig.add_subplot(2, max_rank + 1, max_rank + 3 + i) + ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, + height=gallery_img.shape[0] - 1, + edgecolor=(1, 0, 0), + fill=False, linewidth=5)) + ax.imshow(gallery_img) + ax.set_title(f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}') + ax.axis("off") + plt.tight_layout() filepath = os.path.join(output, "{}.jpg".format(cnt)) fig.savefig(filepath) - plt.cla() - def vis_ranking_list(self, output, num_vis=100, rank_sort='ascending', max_rank=5, actmap=False): - """ + def vis_rank_list(self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending", max_rank=5, + actmap=False): + r"""Visualize rank list of query instance Args: - output (str): a file or directory to save rankling list result. + output (str): a directory to save rank list result. + vis_label (bool): if visualize label of query + num_vis (int): rank_sort (str): save visualization results by which order, if rank_sort is ascending, AP from low to high, vice versa. - num_vis (int): - max_rank (int): + label_sort (bool): + max_rank (int): maximum number of rank result to visualize actmap (bool): """ assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort) @@ -123,7 +152,7 @@ class Visualizer: if rank_sort == 'descending': query_indices = query_indices[::-1] query_indices = query_indices[:num_vis] - self.save_rank_result(query_indices, output, max_rank, actmap) + self.save_rank_result(query_indices, output, max_rank, vis_label, label_sort, actmap) def plot_roc_curve(self): pos_sim, neg_sim = [], [] @@ -155,7 +184,7 @@ class Visualizer: same_cam.extend(self.sim[i, sameCam_idx]) diff_cam.extend(self.sim[i, diffCam_idx]) - fig = mplfigure(figsize=(10, 5)) + fig = plt.figure(figsize=(10, 5)) plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera') plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera') plt.xticks(np.arange(0.1, 1.0, 0.1))