Source code for torchreid.utils.reidtools

from __future__ import absolute_import
from __future__ import print_function

__all__ = ['visualize_ranked_results']

import numpy as np
import os
import os.path as osp
import shutil

from .tools import mkdir_if_missing


def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
    """Visualizes ranked results.

    Supports both image-reid and video-reid.

    For every query, a sub-directory named after the basename of the query
    image (the first image for video-reid) is created under ``save_dir``.
    The query is copied there with rank 0, followed by at most ``topk``
    gallery entries in ascending order of distance. Gallery entries that
    share both pid and camid with the query are skipped and do not consume
    a rank.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        dataset (tuple): a 2-tuple containing (query, gallery), each of which
            contains tuples of (img_path(s), pid, camid).
        save_dir (str): directory to save output images.
        topk (int, optional): denoting top-k images in the rank list to be
            visualized. With a value of 0 or less only the query is copied.
    """
    num_q, num_g = distmat.shape

    print('Visualizing top-{} ranks'.format(topk))
    print('# query: {}\n# gallery {}'.format(num_q, num_g))
    print('Saving images to "{}"'.format(save_dir))

    query, gallery = dataset
    assert num_q == len(query)
    assert num_g == len(gallery)

    # Row i lists gallery indices sorted from nearest to farthest for query i.
    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    def _cp_img_to(src, dst, rank, prefix):
        """Copies one ranked item into ``dst``.

        Args:
            src: image path or tuple (for vidreid)
            dst: target directory
            rank: int, denoting ranked position, starting from 1
            prefix: string
        """
        tag = prefix + '_top' + str(rank).zfill(3)
        if isinstance(src, (tuple, list)):
            # Video-reid: all frames of the tracklet go into one sub-directory.
            dst = osp.join(dst, tag)
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            # Image-reid: keep the original file name as a suffix.
            dst = osp.join(dst, tag + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]

        # A tracklet is identified by its first frame.
        if isinstance(qimg_path, (tuple, list)):
            qname = osp.basename(qimg_path[0])
        else:
            qname = osp.basename(qimg_path)
        qdir = osp.join(save_dir, qname)
        mkdir_if_missing(qdir)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx]:
            # Check the limit before copying so that topk <= 0 copies no
            # gallery images at all.
            if rank_idx > topk:
                break

            gimg_path, gpid, gcamid = gallery[g_idx]
            # Same identity seen by the same camera is not a valid match.
            if qpid == gpid and qcamid == gcamid:
                continue

            _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
            rank_idx += 1

    print("Done")