mirror of https://github.com/JDAI-CV/fast-reid.git
feat: support visualizing label list
add features to support label list visualization, which can be used for label correction or check the hardest samplepull/51/head
parent
9b6fda3830
commit
9addfb0ae2
|
@ -1,6 +1,4 @@
|
|||
python demo/visualize_ranking.py --config-file 'logs/market1501/sbs_R50/config.yaml' \
|
||||
--parallel \
|
||||
--dataset-name 'Market1501' \
|
||||
--output 'logs/test_ranking' \
|
||||
--parallel --vis-label --dataset-name 'Market1501' --output 'logs/test_ranking' \
|
||||
--opts MODEL.WEIGHTS 'logs/market1501/sbs_R50/model_final.pth'
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import argparse
|
|||
import logging
|
||||
import tqdm
|
||||
import sys
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -62,6 +63,11 @@ def get_parser():
|
|||
help="a file or directory to save rankling list result.",
|
||||
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vis-label",
|
||||
action='store_true',
|
||||
help="if visualize label of query instance"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-vis",
|
||||
default=100,
|
||||
|
@ -72,6 +78,11 @@ def get_parser():
|
|||
default="ascending",
|
||||
help="rank order of visualization images by AP metric",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label-sort",
|
||||
default="ascending",
|
||||
help="label order of visualization images by cosine similarity metric",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-rank",
|
||||
default=10,
|
||||
|
@ -115,10 +126,10 @@ if __name__ == '__main__':
|
|||
distmat = distmat.numpy()
|
||||
|
||||
logger.info("Computing APs for all query images ...")
|
||||
cmc, all_ap, all_inp = evaluate_rank(1-distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
cmc, all_ap, all_inp = evaluate_rank(1 - distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
|
||||
visualizer = Visualizer(test_loader.loader.dataset)
|
||||
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
logger.info("Saving ranking list result ...")
|
||||
visualizer.vis_ranking_list(args.output, args.num_vis, rank_sort=args.rank_sort, max_rank=args.max_rank)
|
||||
|
||||
logger.info("Saving rank list result ...")
|
||||
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
|
||||
args.rank_sort, args.label_sort, args.max_rank)
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import os
|
||||
|
||||
import cv2
|
||||
import matplotlib.figure as mplfigure
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
@ -46,9 +45,12 @@ class Visualizer:
|
|||
sort_idx = order[keep]
|
||||
return cmc, sort_idx
|
||||
|
||||
def save_rank_result(self, query_indices, output, max_rank=5, actmap=False):
|
||||
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
|
||||
# fig.suptitle('query/AP/camid sim/true(false)/camid')
|
||||
def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending',
|
||||
actmap=False):
|
||||
if vis_label:
|
||||
fig, axes = plt.subplots(2, max_rank + 1, figsize=(3 * max_rank, 12))
|
||||
else:
|
||||
fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
|
||||
for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
|
||||
all_imgs = []
|
||||
cmc, sort_idx = self.get_matched_result(q_idx)
|
||||
|
@ -58,11 +60,16 @@ class Visualizer:
|
|||
query_name = query_info['img_path'].split('/')[-1]
|
||||
all_imgs.append(query_img)
|
||||
query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3)
|
||||
axes.flat[0].imshow(query_img)
|
||||
axes.flat[0].set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
|
||||
axes.flat[0].axis("off")
|
||||
# print('query' + query_info['img_path'].split('/')[-1])
|
||||
plt.clf()
|
||||
ax = fig.add_subplot(1, max_rank + 1, 1)
|
||||
ax.imshow(query_img)
|
||||
ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
|
||||
ax.axis("off")
|
||||
for i in range(max_rank):
|
||||
if vis_label:
|
||||
ax = fig.add_subplot(2, max_rank + 1, i + 2)
|
||||
else:
|
||||
ax = fig.add_subplot(1, max_rank + 1, i + 2)
|
||||
g_idx = self.num_query + sort_idx[i]
|
||||
gallery_info = self.dataset[g_idx]
|
||||
gallery_img = gallery_info['images']
|
||||
|
@ -71,18 +78,17 @@ class Visualizer:
|
|||
gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
|
||||
if cmc[i] == 1:
|
||||
label = 'true'
|
||||
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
|
||||
height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0),
|
||||
fill=False, linewidth=5))
|
||||
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
|
||||
height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0),
|
||||
fill=False, linewidth=5))
|
||||
else:
|
||||
label = 'false'
|
||||
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
|
||||
height=gallery_img.shape[0] - 1,
|
||||
edgecolor=(0, 0, 1), fill=False, linewidth=5))
|
||||
axes.flat[i + 1].imshow(gallery_img)
|
||||
# print('/'.join(gallery_info['img_path'].split('/')[-2:]))
|
||||
axes.flat[i + 1].set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
|
||||
axes.flat[i + 1].axis("off")
|
||||
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
|
||||
height=gallery_img.shape[0] - 1,
|
||||
edgecolor=(0, 0, 1), fill=False, linewidth=5))
|
||||
ax.imshow(gallery_img)
|
||||
ax.set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
|
||||
ax.axis("off")
|
||||
# if actmap:
|
||||
# act_outputs = []
|
||||
#
|
||||
|
@ -101,19 +107,42 @@ class Visualizer:
|
|||
# acts = self.get_actmap(act_outputs[0], sz)
|
||||
# for i in range(top + 1):
|
||||
# axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
|
||||
if vis_label:
|
||||
label_indice = np.where(cmc == 1)[0]
|
||||
if label_sort == "ascending": label_indice = label_indice[::-1]
|
||||
label_indice = label_indice[:max_rank]
|
||||
for i in range(max_rank):
|
||||
if i >= len(label_indice): break
|
||||
j = label_indice[i]
|
||||
g_idx = self.num_query + sort_idx[j]
|
||||
gallery_info = self.dataset[g_idx]
|
||||
gallery_img = gallery_info['images']
|
||||
cam_id = gallery_info['camid']
|
||||
gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
|
||||
ax = fig.add_subplot(2, max_rank + 1, max_rank + 3 + i)
|
||||
ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
|
||||
height=gallery_img.shape[0] - 1,
|
||||
edgecolor=(1, 0, 0),
|
||||
fill=False, linewidth=5))
|
||||
ax.imshow(gallery_img)
|
||||
ax.set_title(f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}')
|
||||
ax.axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
filepath = os.path.join(output, "{}.jpg".format(cnt))
|
||||
fig.savefig(filepath)
|
||||
plt.cla()
|
||||
|
||||
def vis_ranking_list(self, output, num_vis=100, rank_sort='ascending', max_rank=5, actmap=False):
|
||||
"""
|
||||
def vis_rank_list(self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending", max_rank=5,
|
||||
actmap=False):
|
||||
r"""Visualize rank list of query instance
|
||||
Args:
|
||||
output (str): a file or directory to save rankling list result.
|
||||
output (str): a directory to save rank list result.
|
||||
vis_label (bool): if visualize label of query
|
||||
num_vis (int):
|
||||
rank_sort (str): save visualization results by which order,
|
||||
if rank_sort is ascending, AP from low to high, vice versa.
|
||||
num_vis (int):
|
||||
max_rank (int):
|
||||
label_sort (bool):
|
||||
max_rank (int): maximum number of rank result to visualize
|
||||
actmap (bool):
|
||||
"""
|
||||
assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort)
|
||||
|
@ -123,7 +152,7 @@ class Visualizer:
|
|||
if rank_sort == 'descending': query_indices = query_indices[::-1]
|
||||
|
||||
query_indices = query_indices[:num_vis]
|
||||
self.save_rank_result(query_indices, output, max_rank, actmap)
|
||||
self.save_rank_result(query_indices, output, max_rank, vis_label, label_sort, actmap)
|
||||
|
||||
def plot_roc_curve(self):
|
||||
pos_sim, neg_sim = [], []
|
||||
|
@ -155,7 +184,7 @@ class Visualizer:
|
|||
same_cam.extend(self.sim[i, sameCam_idx])
|
||||
diff_cam.extend(self.sim[i, diffCam_idx])
|
||||
|
||||
fig = mplfigure(figsize=(10, 5))
|
||||
fig = plt.figure(figsize=(10, 5))
|
||||
plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
|
||||
plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
|
||||
plt.xticks(np.arange(0.1, 1.0, 0.1))
|
||||
|
|
Loading…
Reference in New Issue