PyRetri/search/new_reid_search_query.py

120 lines
4.5 KiB
Python
Raw Normal View History

2020-04-02 14:00:49 +08:00
# -*- coding: utf-8 -*-
import json
import importlib
import os
import argparse
from .utils.misc import check_exist, get_dir
from retrieval_tool_box.config import get_defaults_cfg
from retrieval_tool_box.index import build_index_helper, feature_loader
from retrieval_tool_box.evaluate import build_evaluate_helper
fea_names = ["output"]
def load_datasets():
datasets = {
"market_gallery": {
"gallery": "market_gallery",
"query": "market_query",
"train": "market_gallery"
},
"duke_gallery": {
"gallery": "duke_gallery",
"query": "duke_query",
"train": "duke_gallery"
},
}
return datasets
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('--fea_dir', '-fd', default=None, type=str, help="path of feature dirs", required=True)
parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
parser.add_argument("--save_path", "-sp", default=None, type=str, help="path for saving results")
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.fea_dir is not None, 'the feature directory must be provided!'
assert args.search_modules is not None, 'the search modules must be provided!'
assert args.save_path is not None, 'the save path must be provided!'
# init retrieval pipeline settings
cfg = get_defaults_cfg()
# load search space
datasets = load_datasets()
queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
if os.path.exists(args.save_path):
with open(args.save_path, "r") as f:
results = json.load(f)
else:
results = list()
for dir in os.listdir(args.fea_dir):
for data_name, data_args in datasets.items():
for query_name, query_args in queries.items():
if data_name in dir:
# get dirs
gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
# get evaluate setting
evaluate_args = evaluates["reid_overall"]
for dim_proc in query_args.dim_processors.names:
if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
query_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir
for fea_name in fea_names:
result_dict = get_default_result_dict(dir, data_name, query_name, fea_name)
if check_exist(result_dict, results):
print("[Search Query]: config exists...")
continue
print(data_name + '_' + fea_name + '_' + query_name)
# load retrieval pipeline settings
query_args.feature_names = [fea_name]
cfg.index.merge_from_other_cfg(query_args)
cfg.evaluate.merge_from_other_cfg(evaluate_args)
# load features
query_fea, query_info, _ = feature_loader.load(query_fea_dir, [fea_name])
gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])
# build helper and index features
query_helper = build_index_helper(cfg.index)
query_result_info, _, _ = query_helper.do_index(query_fea, query_info, gallery_fea)
# build helper and evaluate results
evaluate_helper = build_evaluate_helper(cfg.evaluate)
mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
# record results
to_save_recall = dict()
for k in recall_at_k:
to_save_recall[str(k)] = recall_at_k[k]
result_dict["mAP"] = float(mAP)
result_dict["recall_at_k"] = to_save_recall
results.append(result_dict)
# save results
with open(args.save_path, "w") as f:
json.dump(results, f)
if __name__ == '__main__':
main()