feat: support visualizing label list

add features to support label list visualization, which can be used
for label correction or check the hardest sample
pull/51/head
liaoxingyu 2020-05-12 21:35:33 +08:00
parent 9b6fda3830
commit 9addfb0ae2
3 changed files with 71 additions and 33 deletions

View File

@ -1,6 +1,4 @@
python demo/visualize_ranking.py --config-file 'logs/market1501/sbs_R50/config.yaml' \
--parallel \
--dataset-name 'Market1501' \
--output 'logs/test_ranking' \
--parallel --vis-label --dataset-name 'Market1501' --output 'logs/test_ranking' \
--opts MODEL.WEIGHTS 'logs/market1501/sbs_R50/model_final.pth'

View File

@ -8,6 +8,7 @@ import argparse
import logging
import tqdm
import sys
import os
import numpy as np
import torch
@ -62,6 +63,11 @@ def get_parser():
help="a file or directory to save rankling list result.",
)
parser.add_argument(
"--vis-label",
action='store_true',
help="if visualize label of query instance"
)
parser.add_argument(
"--num-vis",
default=100,
@ -72,6 +78,11 @@ def get_parser():
default="ascending",
help="rank order of visualization images by AP metric",
)
parser.add_argument(
"--label-sort",
default="ascending",
help="label order of visualization images by cosine similarity metric",
)
parser.add_argument(
"--max-rank",
default=10,
@ -115,10 +126,10 @@ if __name__ == '__main__':
distmat = distmat.numpy()
logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(1-distmat, q_pids, g_pids, q_camids, g_camids)
cmc, all_ap, all_inp = evaluate_rank(1 - distmat, q_pids, g_pids, q_camids, g_camids)
visualizer = Visualizer(test_loader.loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ranking list result ...")
visualizer.vis_ranking_list(args.output, args.num_vis, rank_sort=args.rank_sort, max_rank=args.max_rank)
logger.info("Saving rank list result ...")
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
args.rank_sort, args.label_sort, args.max_rank)

View File

@ -7,7 +7,6 @@
import os
import cv2
import matplotlib.figure as mplfigure
import matplotlib.pyplot as plt
import numpy as np
import tqdm
@ -46,9 +45,12 @@ class Visualizer:
sort_idx = order[keep]
return cmc, sort_idx
def save_rank_result(self, query_indices, output, max_rank=5, actmap=False):
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
# fig.suptitle('query/AP/camid sim/true(false)/camid')
def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending',
actmap=False):
if vis_label:
fig, axes = plt.subplots(2, max_rank + 1, figsize=(3 * max_rank, 12))
else:
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
all_imgs = []
cmc, sort_idx = self.get_matched_result(q_idx)
@ -58,11 +60,16 @@ class Visualizer:
query_name = query_info['img_path'].split('/')[-1]
all_imgs.append(query_img)
query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3)
axes.flat[0].imshow(query_img)
axes.flat[0].set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
axes.flat[0].axis("off")
# print('query' + query_info['img_path'].split('/')[-1])
plt.clf()
ax = fig.add_subplot(1, max_rank + 1, 1)
ax.imshow(query_img)
ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
ax.axis("off")
for i in range(max_rank):
if vis_label:
ax = fig.add_subplot(2, max_rank + 1, i + 2)
else:
ax = fig.add_subplot(1, max_rank + 1, i + 2)
g_idx = self.num_query + sort_idx[i]
gallery_info = self.dataset[g_idx]
gallery_img = gallery_info['images']
@ -71,18 +78,17 @@ class Visualizer:
gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
if cmc[i] == 1:
label = 'true'
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0),
fill=False, linewidth=5))
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0),
fill=False, linewidth=5))
else:
label = 'false'
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
height=gallery_img.shape[0] - 1,
edgecolor=(0, 0, 1), fill=False, linewidth=5))
axes.flat[i + 1].imshow(gallery_img)
# print('/'.join(gallery_info['img_path'].split('/')[-2:]))
axes.flat[i + 1].set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
axes.flat[i + 1].axis("off")
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
height=gallery_img.shape[0] - 1,
edgecolor=(0, 0, 1), fill=False, linewidth=5))
ax.imshow(gallery_img)
ax.set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
ax.axis("off")
# if actmap:
# act_outputs = []
#
@ -101,19 +107,42 @@ class Visualizer:
# acts = self.get_actmap(act_outputs[0], sz)
# for i in range(top + 1):
# axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
if vis_label:
label_indice = np.where(cmc == 1)[0]
if label_sort == "ascending": label_indice = label_indice[::-1]
label_indice = label_indice[:max_rank]
for i in range(max_rank):
if i >= len(label_indice): break
j = label_indice[i]
g_idx = self.num_query + sort_idx[j]
gallery_info = self.dataset[g_idx]
gallery_img = gallery_info['images']
cam_id = gallery_info['camid']
gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
ax = fig.add_subplot(2, max_rank + 1, max_rank + 3 + i)
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
height=gallery_img.shape[0] - 1,
edgecolor=(1, 0, 0),
fill=False, linewidth=5))
ax.imshow(gallery_img)
ax.set_title(f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}')
ax.axis("off")
plt.tight_layout()
filepath = os.path.join(output, "{}.jpg".format(cnt))
fig.savefig(filepath)
plt.cla()
def vis_ranking_list(self, output, num_vis=100, rank_sort='ascending', max_rank=5, actmap=False):
"""
def vis_rank_list(self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending", max_rank=5,
actmap=False):
r"""Visualize rank list of query instance
Args:
output (str): a file or directory to save rankling list result.
output (str): a directory to save rank list result.
vis_label (bool): if visualize label of query
num_vis (int):
rank_sort (str): save visualization results by which order,
if rank_sort is ascending, AP from low to high, vice versa.
num_vis (int):
max_rank (int):
label_sort (bool):
max_rank (int): maximum number of rank result to visualize
actmap (bool):
"""
assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort)
@ -123,7 +152,7 @@ class Visualizer:
if rank_sort == 'descending': query_indices = query_indices[::-1]
query_indices = query_indices[:num_vis]
self.save_rank_result(query_indices, output, max_rank, actmap)
self.save_rank_result(query_indices, output, max_rank, vis_label, label_sort, actmap)
def plot_roc_curve(self):
pos_sim, neg_sim = [], []
@ -155,7 +184,7 @@ class Visualizer:
same_cam.extend(self.sim[i, sameCam_idx])
diff_cam.extend(self.sim[i, diffCam_idx])
fig = mplfigure(figsize=(10, 5))
fig = plt.figure(figsize=(10, 5))
plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
plt.xticks(np.arange(0.1, 1.0, 0.1))