mirror of https://github.com/PyRetri/PyRetri.git
# -*- coding: utf-8 -*-

import argparse
import importlib
import json
import os

from pyretri.config import get_defaults_cfg
from pyretri.query import build_query_helper
from pyretri.evaluate import build_evaluate_helper
from pyretri.index import feature_loader

# feature names to evaluate, keyed by backbone family
vgg_fea = ["pool5_PWA"]
res_fea = ["pool5_PWA"]

# maps a gallery-directory keyword to its query set and to the feature set
# used to train PCA/SVD whitening (Oxford is whitened with Paris features)
task_mapping = {
    "oxford_gallery": {
        "gallery": "oxford_gallery",
        "query": "oxford_query",
        "train_fea_dir": "paris"
    },
    "cub_gallery": {
        "gallery": "cub_gallery",
        "query": "cub_query",
        "train_fea_dir": "cub_gallery"
    },
    "indoor_gallery": {
        "gallery": "indoor_gallery",
        "query": "indoor_query",
        "train_fea_dir": "indoor_gallery"
    },
    "caltech101_gallery": {
        "gallery": "caltech101_gallery",
        "query": "caltech101_query",
        "train_fea_dir": "caltech101_gallery"
    }
}
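
# A minimal illustration (hypothetical directory name): a feature directory
# named "oxford_gallery_vgg16" matches the "oxford_gallery" task above; its
# queries are read from "oxford_query_vgg16" and, because the name contains
# "vgg", the vgg_fea feature names are searched.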


def check_exist(now_res, exist_results):
    """Return True if now_res matches a saved result on all of its keys."""
    for e_r in exist_results:
        total_equal = True
        for key in now_res:
            if now_res[key] != e_r[key]:
                total_equal = False
                break
        if total_equal:
            return True
    return False
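
# Sketch of the dedup check (hypothetical records):
#   check_exist({"task_name": "oxford"}, [{"task_name": "oxford", "mAP": 0.8}])
# returns True because every key of the new record matches the stored one.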


def get_default_result_dict(dir_name, task_name, query_name, fea_name):
    """Build the base record describing one search configuration."""
    result_dict = {
        "task_name": task_name.split("_")[0],
        "dataprocess": dir_name.split("_")[0],
        "model_name": "_".join(dir_name.split("_")[-2:]),
        "feature_map_name": fea_name.split("_")[0],
        "fea_process_name": query_name
    }

    # fc features have no spatial map, hence no aggregator
    if fea_name == "fc":
        result_dict["aggregator_name"] = "none"
    else:
        result_dict["aggregator_name"] = fea_name.split("_")[1]

    return result_dict
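
# For example (hypothetical arguments),
#   get_default_result_dict("oxford_gallery_vgg16_model", "oxford_gallery",
#                           "pca_query", "pool5_PWA")
# yields task_name "oxford", dataprocess "oxford", model_name "vgg16_model",
# feature_map_name "pool5" and aggregator_name "PWA".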


def parse_args():
    parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
    parser.add_argument('--fea_dir', '-fd', default=None, type=str, help="path of feature dirs", required=True)
    parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
    parser.add_argument("--save_path", "-sp", default=None, type=str, help="path for saving results")
    args = parser.parse_args()
    return args
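
# Example invocation (hypothetical paths):
#   python search_query.py -fd /data/features -sm search_modules -sp results.json
# The package passed via -sm must provide a query_dict module exposing
# `queries` and `evaluates`; see main() below.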


def main():
    # init args
    args = parse_args()
    assert args.fea_dir is not None, 'the feature directory must be provided!'
    assert args.search_modules is not None, 'the search modules must be provided!'
    assert args.save_path is not None, 'the save path must be provided!'

    # init retrieval pipeline settings
    cfg = get_defaults_cfg()

    # load search space
    queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
    evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
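
    # The search-module package is assumed to look roughly like this
    # (hypothetical layout; only the two dicts are required):
    #   <search_modules>/
    #       __init__.py
    #       query_dict.py   # defines `queries` and `evaluates`, each mapping
    #                       # a name to a config node used below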

    # resume from previously saved results, if any
    if os.path.exists(args.save_path):
        with open(args.save_path, "r") as f:
            results = json.load(f)
    else:
        results = list()

    fea_dirs = os.listdir(args.fea_dir)
    for q_cnt, dir_name in enumerate(fea_dirs, start=1):
        print("Processing {} / {} feature directories...".format(q_cnt, len(fea_dirs)))
        for query_name, query_args in queries.items():
            for task_name in task_mapping:
                if task_name in dir_name:
                    gallery_fea_dir = os.path.join(args.fea_dir, dir_name)

                    # pick the candidate feature names by backbone family
                    if "vgg" in dir_name:
                        fea_names = vgg_fea
                    else:
                        fea_names = res_fea

                    for fea_name in fea_names:
                        # derive query and whitening-training feature dirs from
                        # the gallery dir by keyword substitution
                        query_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["query"])
                        train_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["train_fea_dir"])

                        # point PCA/SVD post-processors at their training features
                        for post_proc in ["PartPCA", "PartSVD"]:
                            if post_proc in query_args.post_processors.names:
                                query_args.post_processors[post_proc].train_fea_dir = train_fea_dir

                        query_args.gallery_fea_dir, query_args.query_fea_dir = gallery_fea_dir, query_fea_dir
                        query_args.feature_names = [fea_name]

                        if task_name == "oxford_gallery":
                            evaluate = evaluates["oxford_overall"]
                        else:
                            evaluate = evaluates["overall"]

                        result_dict = get_default_result_dict(dir_name, task_name, query_name, fea_name)

                        # skip configurations whose results are already saved
                        if check_exist(result_dict, results):
                            print("[Search Query]: config exists...")
                            continue

                        # load retrieval pipeline settings
                        cfg.query.merge_from_other_cfg(query_args)
                        cfg.evaluate.merge_from_other_cfg(evaluate)

                        # load features
                        query_fea, query_info, _ = feature_loader.load(cfg.query.query_fea_dir, cfg.query.feature_names)
                        gallery_fea, gallery_info, _ = feature_loader.load(cfg.query.gallery_fea_dir, cfg.query.feature_names)

                        # build helper and search the gallery
                        query_helper = build_query_helper(cfg.query)
                        query_result_info, _, _ = query_helper.do_query(query_fea, query_info, gallery_fea)

                        # build helper and evaluate results
                        evaluate_helper = build_evaluate_helper(cfg.evaluate)
                        mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)

                        # save results (JSON object keys must be strings)
                        to_save_dict = dict()
                        for k in recall_at_k:
                            to_save_dict[str(k)] = recall_at_k[k]
                        result_dict["mAP"] = float(mAP)
                        result_dict["recall_at_k"] = to_save_dict

                        results.append(result_dict)
                        with open(args.save_path, "w") as f:
                            json.dump(results, f)
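
# Each saved record thus looks like (values hypothetical):
#   {"task_name": "oxford", "dataprocess": "oxford", "model_name": "vgg16_model",
#    "feature_map_name": "pool5", "fea_process_name": "pca_query",
#    "aggregator_name": "PWA", "mAP": 62.3, "recall_at_k": {"1": 80.0}}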


if __name__ == '__main__':
    main()