style: remove title in visualization

pull/50/head
liaoxingyu 2020-05-11 14:12:29 +08:00
parent 13bb03eb07
commit 9b6fda3830
2 changed files with 13 additions and 6 deletions

View File

@ -23,7 +23,7 @@ from predictor import FeatureExtractionDemo
from fastreid.utils.visualizer import Visualizer
cudnn.benchmark = True
logger = logging.getLogger('fastreid.visualize.ranking')
logger = logging.getLogger('fastreid')
def setup_cfg(args):
@ -66,7 +66,11 @@ def get_parser():
"--num-vis",
default=100,
help="number of query images to be visualized",
)
parser.add_argument(
"--rank-sort",
default="ascending",
help="rank order of visualization images by AP metric",
)
parser.add_argument(
"--max-rank",
@ -116,5 +120,5 @@ if __name__ == '__main__':
visualizer = Visualizer(test_loader.loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ranking list result ...")
visualizer.vis_ranking_list(args.output, args.num_vis, max_rank=args.max_rank)
visualizer.vis_ranking_list(args.output, args.num_vis, rank_sort=args.rank_sort, max_rank=args.max_rank)

View File

@ -47,8 +47,8 @@ class Visualizer:
return cmc, sort_idx
def save_rank_result(self, query_indices, output, max_rank=5, actmap=False):
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 5))
fig.suptitle('query/AP/camid sim/true(false)/camid')
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
# fig.suptitle('query/AP/camid sim/true(false)/camid')
for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
all_imgs = []
cmc, sort_idx = self.get_matched_result(q_idx)
@ -59,7 +59,8 @@ class Visualizer:
all_imgs.append(query_img)
query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3)
axes.flat[0].imshow(query_img)
axes.flat[0].set_title('{}/AP:{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
axes.flat[0].set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
axes.flat[0].axis("off")
# print('query' + query_info['img_path'].split('/')[-1])
for i in range(max_rank):
g_idx = self.num_query + sort_idx[i]
@ -81,6 +82,7 @@ class Visualizer:
axes.flat[i + 1].imshow(gallery_img)
# print('/'.join(gallery_info['img_path'].split('/')[-2:]))
axes.flat[i + 1].set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
axes.flat[i + 1].axis("off")
# if actmap:
# act_outputs = []
#
@ -99,6 +101,7 @@ class Visualizer:
# acts = self.get_actmap(act_outputs[0], sz)
# for i in range(top + 1):
# axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
plt.tight_layout()
filepath = os.path.join(output, "{}.jpg".format(cnt))
fig.savefig(filepath)
plt.cla()