75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
|
|
__all__ = ['visualize_ranked_results']
|
|
|
|
import numpy as np
|
|
import os
|
|
import os.path as osp
|
|
import shutil
|
|
|
|
from .tools import mkdir_if_missing
|
|
|
|
|
|
def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
    """Visualizes ranked results.

    Supports both image-reid and video-reid.

    For every query, a sub-directory of ``save_dir`` is created, named after
    the base name of the query image (or of the first image when the query is
    a tuple/list of image paths). The query is copied there with rank 0,
    followed by at most ``topk`` gallery entries in order of increasing
    distance. Gallery entries that share both pid and camid with the query
    are skipped and do not consume a rank.

    Output naming inside each query directory:
        - single image: ``<prefix>_top<rank:03d>_name_<basename>``
        - tuple/list of images: a directory ``<prefix>_top<rank:03d>``
          holding copies of every image in the sequence
    where ``<prefix>`` is ``query`` or ``gallery``.

    NOTE: queries whose (first) image shares the same base name are written
    into the same sub-directory.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        dataset (tuple): a 2-tuple containing (query, gallery), each of which contains
            tuples of (img_path(s), pid, camid).
        save_dir (str): directory to save output images.
        topk (int, optional): denoting top-k images in the rank list to be visualized.
            A value of 0 or less copies only the query images.

    Raises:
        AssertionError: if the shape of ``distmat`` does not match the number
            of query / gallery entries in ``dataset``.
    """
    num_q, num_g = distmat.shape

    print('Visualizing top-{} ranks'.format(topk))
    print('# query: {}\n# gallery {}'.format(num_q, num_g))
    print('Saving images to "{}"'.format(save_dir))

    query, gallery = dataset
    assert num_q == len(query), \
        'distmat has {} rows but dataset has {} queries'.format(num_q, len(query))
    assert num_g == len(gallery), \
        'distmat has {} columns but dataset has {} gallery entries'.format(num_g, len(gallery))

    # Row i lists gallery indices for query i, closest (smallest distance) first.
    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    def _cp_img_to(src, dst, rank, prefix):
        """
        Args:
            src: image path or tuple (for vidreid)
            dst: target directory
            rank: int, denoting ranked position, starting from 1
            prefix: string
        """
        stem = prefix + '_top' + str(rank).zfill(3)
        if isinstance(src, (tuple, list)):
            # A sequence of frames gets its own directory named after the rank.
            dst = osp.join(dst, stem)
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(dst, stem + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    for q_idx, (qimg_path, qpid, qcamid) in enumerate(query):
        # For a sequence of frames, the first frame names the query directory.
        if isinstance(qimg_path, (tuple, list)):
            qname = osp.basename(qimg_path[0])
        else:
            qname = osp.basename(qimg_path)
        qdir = osp.join(save_dir, qname)
        mkdir_if_missing(qdir)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx]:
            # Check the budget *before* copying so that topk <= 0 copies
            # no gallery image at all.
            if rank_idx > topk:
                break
            gimg_path, gpid, gcamid = gallery[g_idx]
            # Skip gallery entries with the same identity seen by the same camera.
            if qpid == gpid and qcamid == gcamid:
                continue
            _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
            rank_idx += 1

    print("Done")