deep-person-reid/torchreid/utils/reidtools.py

117 lines
4.2 KiB
Python
Raw Normal View History

2018-08-01 17:51:42 +08:00
from __future__ import absolute_import
from __future__ import print_function
2019-03-20 01:26:08 +08:00
# Names exported by ``from <this module> import *``.
__all__ = ['visualize_ranked_results']
2018-08-01 17:51:42 +08:00
import numpy as np
import os
import os.path as osp
import shutil
import cv2
from matplotlib import pyplot as plt
2018-08-01 17:51:42 +08:00
2019-03-20 01:26:08 +08:00
from .tools import mkdir_if_missing
2018-08-01 17:51:42 +08:00
# Value passed as ``fontsize`` for the subplot titles ("Query", "Rank-k")
# drawn by visualize_ranked_results.
PLOT_FONT_SIZE = 3
def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, save_dir='', topk=20):
    """Visualizes ranked results.

    Supports both image-reid and video-reid.

    For image-reid, the query and its top-k gallery results are plotted side by
    side in a single figure, saved as ``<query name>.pdf`` under ``save_dir``.
    Titles of gallery images are green when the identity matches the query and
    red otherwise.

    For video-reid, a folder named after the first frame of the query tracklet
    is created under ``save_dir``. It contains one sub-folder for the query
    (``query_top000``) and one per ranked gallery tracklet
    (``gallery_topXXX_TRUE`` / ``gallery_topXXX_FALSE``, where the suffix tells
    whether the gallery identity matches the query identity).

    Gallery samples sharing both pid and camid with the query are skipped and
    do not consume a rank position.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        dataset (tuple): a 2-tuple containing (query, gallery), each of which contains
            tuples of (img_path(s), pid, camid).
        data_type (str): "image" or "video". Any value other than "image" is
            handled as "video".
        width (int, optional): resized image width. Default is 128.
        height (int, optional): resized image height. Default is 256.
        save_dir (str): directory to save output images.
        topk (int, optional): denoting top-k images in the rank list to be visualized.

    Raises:
        AssertionError: if the shape of ``distmat`` does not agree with the
            sizes of query and gallery.
        IOError: if an image cannot be read (image-reid only).
    """
    num_q, num_g = distmat.shape

    print('# query: {}\n# gallery {}'.format(num_q, num_g))
    print('Visualizing top-{} ranks ...'.format(topk))

    query, gallery = dataset
    assert num_q == len(query)
    assert num_g == len(gallery)

    # Each row holds gallery indices sorted by increasing distance to the query.
    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)
    is_image = data_type == 'image'

    def _read_rgb(img_path):
        """Loads an image, resizes it to (width, height) and returns it in RGB order."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('Cannot read image "{}"'.format(img_path))
        img = cv2.resize(img, (width, height))
        # OpenCV decodes to BGR while matplotlib's imshow expects RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _cp_img_to(src, dst, rank, prefix, matched=False):
        """
        Args:
            src: image path or tuple (for vidreid)
            dst: target directory
            rank: int, denoting ranked position, starting from 1
            prefix: string
            matched: bool
        """
        rank_tag = prefix + '_top' + str(rank).zfill(3)
        if isinstance(src, (tuple, list)):
            # Tracklet: copy every frame into its own sub-folder.
            if prefix == 'gallery':
                rank_tag += '_' + ('TRUE' if matched else 'FALSE')
            dst = osp.join(dst, rank_tag)
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(dst, rank_tag + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]

        # A video query is a sequence of frame paths; name the output after
        # its first frame so that osp.splitext receives a string.
        if isinstance(qimg_path, (tuple, list)):
            qname_src = qimg_path[0]
        else:
            qname_src = qimg_path
        qname = osp.basename(osp.splitext(qname_src)[0])

        if is_image:
            fig = plt.figure()
            ax = fig.add_subplot(1, topk + 1, 1)  # totally 1 query and topk gallery
            ax.axis('off')
            ax.set_title('Query', fontsize=PLOT_FONT_SIZE)
            ax.imshow(_read_rgb(qimg_path))
        else:
            qdir = osp.join(save_dir, qname)
            mkdir_if_missing(qdir)
            _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx]:
            # Checked before plotting so that topk <= 0 never requests a
            # subplot index beyond the 1 x (topk+1) grid.
            if rank_idx > topk:
                break

            gimg_path, gpid, gcamid = gallery[g_idx]

            # Same identity seen by the same camera: not a valid match
            # candidate, so it does not take a rank position.
            if qpid == gpid and qcamid == gcamid:
                continue

            matched = gpid == qpid
            if is_image:
                ax = fig.add_subplot(1, topk + 1, rank_idx + 1)
                ax.axis('off')
                ax.set_title(
                    'Rank-' + str(rank_idx),
                    fontsize=PLOT_FONT_SIZE,
                    color='green' if matched else 'red'
                )
                ax.imshow(_read_rgb(gimg_path))
            else:
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery', matched=matched)

            rank_idx += 1

        if is_image:
            fig.savefig(osp.join(save_dir, qname + '.pdf'), bbox_inches='tight')
            # Close this specific figure so figures do not pile up in memory.
            plt.close(fig)

        if (q_idx + 1) % 100 == 0:
            print('- done {}/{}'.format(q_idx + 1, num_q))

    print('Done. Images have been saved to "{}" ...'.format(save_dir))