PyRetri/search/search_index.py

# -*- coding: utf-8 -*-
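
"""Batch search script for PyRetri.

For every feature sub-directory, matching dataset, index configuration, and
feature name, build an index helper, evaluate the retrieval results, and
append the mAP / recall@k of that configuration to a JSON file. Existing
entries in the results file are skipped, so interrupted sweeps can resume.

A hypothetical invocation (the paths below are illustrative, not from the
original source):

    python search_index.py -fd /path/to/features -sm search_modules -sp results.json
"""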
import json
import importlib
import os
import argparse

from utils.misc import check_result_exist, get_dir, get_default_result_dict

from pyretri.config import get_defaults_cfg
from pyretri.index import build_index_helper, feature_loader
from pyretri.evaluate import build_evaluate_helper

# # gap, gmp, gem, spoc, crow
# vgg_fea = ["pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
#            "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow",
#            "fc"]
# res_fea = ["pool3_GAP", "pool3_GMP", "pool3_GeM", "pool3_SPoC", "pool3_Crow",
#            "pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
#            "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow"]
# # scda, rmac
# vgg_fea = ["pool5_SCDA", "pool5_RMAC"]
# res_fea = ["pool5_SCDA", "pool5_RMAC"]

# pwa
vgg_fea = ["pool5_PWA"]
res_fea = ["pool5_PWA"]
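
# Feature names follow the "<layer>_<aggregator>" pattern used by the PyRetri
# feature extractor; swap in one of the commented blocks above to sweep a
# different aggregation family (GAP/GMP/GeM/SPoC/CroW, SCDA/R-MAC, or PWA).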


def load_datasets():
    datasets = {
        "oxford_gallery": {
            "gallery": "oxford_gallery",
            "query": "oxford_query",
            "train": "paris_all"
        },
        "cub_gallery": {
            "gallery": "cub_gallery",
            "query": "cub_query",
            "train": "cub_gallery"
        },
        "indoor_gallery": {
            "gallery": "indoor_gallery",
            "query": "indoor_query",
            "train": "indoor_gallery"
        },
        "caltech_gallery": {
            "gallery": "caltech_gallery",
            "query": "caltech_query",
            "train": "caltech_gallery"
        }
    }
    return datasets
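
# Note: each entry maps a gallery to its query set and to the feature set used
# to fit the dimension-reduction (PCA/SVD) step in main(). Oxford follows the
# common retrieval convention of learning the reduction on Paris ("paris_all")
# instead of its own gallery; the other datasets train on their own galleries.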


def get_evaluate(fea_dir, evaluates):
    if "oxford" in fea_dir:
        evaluate = evaluates["oxford_overall"]
    else:
        evaluate = evaluates["overall"]
    return evaluate


def get_fea_names(fea_dir):
    if "vgg" in fea_dir:
        fea_names = vgg_fea
    else:
        fea_names = res_fea
    return fea_names


def parse_args():
    parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
    parser.add_argument('--fea_dir', '-fd', default=None, type=str, help="path of feature dirs", required=True)
    parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
    parser.add_argument("--save_path", "-sp", default=None, type=str, help="path for saving results")
    args = parser.parse_args()
    return args
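
# Note: --fea_dir is already enforced by argparse via required=True, while
# --search_modules and --save_path are only checked by the asserts in main().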


def main():
    # init args
    args = parse_args()
    assert args.fea_dir is not None, 'the feature directory must be provided!'
    assert args.search_modules is not None, 'the search modules must be provided!'
    assert args.save_path is not None, 'the save path must be provided!'

    # init retrieval pipeline settings
    cfg = get_defaults_cfg()

    # load search space
    datasets = load_datasets()
    indexes = importlib.import_module("{}.index_dict".format(args.search_modules)).indexes
    evaluates = importlib.import_module("{}.index_dict".format(args.search_modules)).evaluates

    if os.path.exists(args.save_path):
        with open(args.save_path, "r") as f:
            results = json.load(f)
    else:
        results = list()

    for dir in os.listdir(args.fea_dir):
        for data_name, data_args in datasets.items():
            for index_name, index_args in indexes.items():
                if data_name in dir:
                    print(dir)

                    # get dirs
                    gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)

                    # get evaluate setting
                    evaluate_args = get_evaluate(gallery_fea_dir, evaluates)

                    # get feature names
                    fea_names = get_fea_names(gallery_fea_dir)

                    # set train feature path for dimension reduction processes
                    for dim_proc in index_args.dim_processors.names:
                        if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
                            index_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir

                    for fea_name in fea_names:
                        result_dict = get_default_result_dict(dir, data_name, index_name, fea_name)
                        if check_result_exist(result_dict, results):
                            print("[Search Query]: config exists...")
                            continue

                        # load retrieval pipeline settings
                        index_args.feature_names = [fea_name]
                        cfg.index.merge_from_other_cfg(index_args)
                        cfg.evaluate.merge_from_other_cfg(evaluate_args)

                        # load features
                        query_fea, query_info, _ = feature_loader.load(query_fea_dir, [fea_name])
                        gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])

                        # build helper and index features
                        index_helper = build_index_helper(cfg.index)
                        index_result_info, _, _ = index_helper.do_index(query_fea, query_info, gallery_fea)

                        # build helper and evaluate results
                        evaluate_helper = build_evaluate_helper(cfg.evaluate)
                        mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)

                        # record results
                        to_save_recall = dict()
                        for k in recall_at_k:
                            to_save_recall[str(k)] = recall_at_k[k]
                        result_dict["mAP"] = float(mAP)
                        result_dict["recall_at_k"] = to_save_recall
                        results.append(result_dict)

                        # save results
                        with open(args.save_path, "w") as f:
                            json.dump(results, f)


if __name__ == '__main__':
    main()