PyRetri/pyretri/index/helper/helper.py

108 lines
3.6 KiB
Python

# -*- coding: utf-8 -*-
import os
import shutil
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

from ..dim_processor import DimProcessorBase
from ..feature_enhancer import EnhanceBase
from ..metric import MetricBase
from ..re_ranker import ReRankerBase
from ..utils import feature_loader
class IndexHelper:
    """
    A helper class to index features.

    It chains the configured dimension processors, feature enhancer, metric and
    re-ranker to rank gallery items for every query, and offers small utilities
    to display or save the top-k retrieved gallery images of a single query.
    """

    def __init__(
            self,
            dim_processors: List[DimProcessorBase],
            feature_enhancer: EnhanceBase,
            metric: MetricBase,
            re_ranker: ReRankerBase,
    ):
        """
        Args:
            dim_processors (list): processors applied, in order, to both the query and the gallery features.
            feature_enhancer (EnhanceBase): enhancer applied to the gallery features only.
            metric (MetricBase): callable taking (query features, gallery features) and returning
                a pair ``(dis, sorted_index)``.
            re_ranker (ReRankerBase): callable taking the features plus ``dis`` and ``sorted_index``
                keyword arguments and returning the refined ``sorted_index``.
        """
        self.dim_procs = dim_processors
        self.feature_enhance = feature_enhancer
        self.metric = metric
        self.re_rank = re_ranker

    def show_topk_retrieved_images(self, single_query_info: Dict, topk: int, gallery_info: List[Dict]) -> None:
        """
        Show the top-k retrieved images of one query, one figure per image.

        Args:
            single_query_info (dict): a dict of single query information; must contain
                "ranked_neighbors_idx" (as filled in by ``do_index``).
            topk (int): number of the nearest images to be shown.
            gallery_info (list): a list of gallery set information; each entry must contain "path".
        """
        query_topk_idx = single_query_info["ranked_neighbors_idx"][:topk]
        for idx in query_topk_idx:
            img_path = gallery_info[idx]["path"]
            # plt.imshow expects pixel data, not a file path, so decode the image first.
            img = plt.imread(img_path)
            fig = plt.figure()
            plt.imshow(img)
            plt.show()
            # Release the figure; non-interactive backends otherwise keep every figure open.
            plt.close(fig)

    def save_topk_retrieved_images(self, save_path: str, single_query_info: Dict, topk: int, gallery_info: List[Dict]) -> None:
        """
        Save (copy) the top-k retrieved images of one query into a directory.

        Each image is copied as ``<gallery index><ext>``, where ``<ext>`` is the extension
        of the source file (``.png`` when the source has no extension).

        Args:
            save_path (str): the directory to save the retrieved images to; created if missing.
            single_query_info (dict): a dict of single query information; must contain
                "ranked_neighbors_idx" (as filled in by ``do_index``).
            topk (int): number of the nearest images to be saved.
            gallery_info (list): a list of gallery set information; each entry must contain "path".
        """
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        query_topk_idx = single_query_info["ranked_neighbors_idx"][:topk]
        for idx in query_topk_idx:
            img_path = gallery_info[idx]["path"]
            # Keep the source extension so the copied bytes are not mislabeled (e.g. a JPEG named .png).
            ext = os.path.splitext(img_path)[1] or ".png"
            shutil.copy(img_path, os.path.join(save_path, str(idx) + ext))

    def do_index(self, query_fea: np.ndarray, query_info: List, gallery_fea: np.ndarray) -> Tuple[List, torch.Tensor, torch.Tensor]:
        """
        Index the query features against the gallery features.

        Pipeline: dimension processors (query and gallery) -> conversion with ``torch.Tensor``
        -> feature enhancement (gallery only) -> metric -> re-ranking.

        Args:
            query_fea (np.ndarray): query set features.
            query_info (list): a list of query set information; each entry is a dict that is
                updated in place with its "ranked_neighbors_idx" (a list of gallery indices).
            gallery_fea (np.ndarray): gallery set features.

        Returns:
            tuple(List, torch.Tensor, torch.Tensor): the query information (the same list, updated
                in place), and the query and gallery features after processing.
        """
        for dim_proc in self.dim_procs:
            query_fea, gallery_fea = dim_proc(query_fea), dim_proc(gallery_fea)
        # torch.Tensor(...) yields the default tensor type (float32 unless reconfigured).
        # NOTE: features stay on the CPU; move them to another device here if needed.
        query_fea, gallery_fea = torch.Tensor(query_fea), torch.Tensor(gallery_fea)
        gallery_fea = self.feature_enhance(gallery_fea)
        dis, sorted_index = self.metric(query_fea, gallery_fea)
        sorted_index = self.re_rank(query_fea, gallery_fea, dis=dis, sorted_index=sorted_index)
        # Row i of sorted_index holds the ranked gallery indices for the i-th query.
        for i, info in enumerate(query_info):
            info["ranked_neighbors_idx"] = sorted_index[i].tolist()
        return query_info, query_fea, gallery_fea