From 2383698d520875c90e4f21a4fbcf55ef11c125c6 Mon Sep 17 00:00:00 2001 From: Abner Ayala-Acevedo <49654624+avn3r-dn@users.noreply.github.com> Date: Fri, 23 Aug 2019 12:12:32 -0700 Subject: [PATCH] Fix Visual Rank Video Reid Error Fix video reid error while making sure it still works for image reid. When trying to save visual ranks for video reid. It crashes on statement `qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path)[0]))` that is because qimg_path can be a tuple or list for video reid. Hence, osp.splittext gives error when input is not a str. To fix this I make sure that osp.splitext always gets pass an str regardless if we doing video or image reid. Let me know if you have any questions. --- torchreid/utils/reidtools.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchreid/utils/reidtools.py b/torchreid/utils/reidtools.py index 0573843..0855ad2 100644 --- a/torchreid/utils/reidtools.py +++ b/torchreid/utils/reidtools.py @@ -77,7 +77,8 @@ def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, qimg_path, qpid, qcamid = query[q_idx] num_cols = topk + 1 grid_img = 255 * np.ones((height, num_cols*width+topk*GRID_SPACING+QUERY_EXTRA_SPACING, 3), dtype=np.uint8) - + qimg_path_name = qimg_path[0] if isinstance(qimg_path, tuple) or isinstance(qimg_path, list) else qimg_path + if data_type == 'image': qimg = cv2.imread(qimg_path) qimg = cv2.resize(qimg, (width, height)) @@ -85,7 +86,7 @@ def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, qimg = cv2.resize(qimg, (width, height)) # resize twice to ensure that the border width is consistent across images grid_img[:, :width, :] = qimg else: - qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path)[0])) + qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path_name)[0])) mkdir_if_missing(qdir) _cp_img_to(qimg_path, qdir, rank=0, prefix='query') @@ -112,7 +113,7 @@ def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, if rank_idx > topk: break - imname = osp.basename(osp.splitext(qimg_path)[0]) + imname = osp.basename(osp.splitext(qimg_path_name)[0]) cv2.imwrite(osp.join(save_dir, imname+'.jpg'), grid_img) if (q_idx+1) % 100 == 0: