PyRetri/search/search_extract.py

# -*- coding: utf-8 -*-
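
# Exhaustively extract features for every (dataset, pre-process, model)
# combination declared in the given search-space modules.
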
import os
import argparse
import importlib

from pyretri.config import get_defaults_cfg
from pyretri.datasets import build_folder, build_loader
from pyretri.models import build_model
from pyretri.extract import build_extract_helper


def load_datasets():
    data_json_dir = "/home/songrenjie/projects/RetrievalToolBox/new_data_jsons/"
    datasets = {
        "oxford_gallery": os.path.join(data_json_dir, "oxford_gallery.json"),
        "oxford_query": os.path.join(data_json_dir, "oxford_query.json"),
        "cub_gallery": os.path.join(data_json_dir, "cub_gallery.json"),
        "cub_query": os.path.join(data_json_dir, "cub_query.json"),
        "indoor_gallery": os.path.join(data_json_dir, "indoor_gallery.json"),
        "indoor_query": os.path.join(data_json_dir, "indoor_query.json"),
        "caltech_gallery": os.path.join(data_json_dir, "caltech_gallery.json"),
        "caltech_query": os.path.join(data_json_dir, "caltech_query.json"),
        "paris_all": os.path.join(data_json_dir, "paris.json"),
    }
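    # fail fast if any annotation json is missing, before extraction starts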
    for data_path in datasets.values():
        assert os.path.exists(data_path), "non-existent dataset path {}".format(data_path)
    return datasets


def parse_args():
    parser = argparse.ArgumentParser(description='A toolbox for deep learning-based image retrieval')
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
    parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for the extracted features")
parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
    args = parser.parse_args()
    return args


def main():
    # init args
    args = parse_args()
    assert args.save_path is not None, 'the save path must be provided!'
    assert args.search_modules is not None, 'the search modules must be provided!'

    # init retrieval pipeline settings
    cfg = get_defaults_cfg()

    # load search space
    datasets = load_datasets()
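    # the --search_modules package is expected to expose pre_process_dict.pre_processes,
    # extract_dict.models and extract_dict.extracts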
    pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
    models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
    extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts

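    # each combination below produces one feature directory under args.save_path,
    # named "<dataset>_<pre_process>_<model>"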
    # search in an exhaustive way
    for data_name, data_args in datasets.items():
        for pre_proc_name, pre_proc_args in pre_processes.items():
            for model_name, model_args in models.items():
                feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
                print(feature_full_name)
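
                # skip combinations whose features were already extracted in a previous run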
                if os.path.exists(os.path.join(args.save_path, feature_full_name)):
                    print("[Search Extract]: config exists...")
                    continue

                # load retrieval pipeline settings
                cfg.datasets.merge_from_other_cfg(pre_proc_args)
                cfg.model.merge_from_other_cfg(model_args)
                cfg.extract.merge_from_other_cfg(extracts[model_name])

                # build dataset and dataloader
                dataset = build_folder(data_args, cfg.datasets)
                dataloader = build_loader(dataset, cfg.datasets)

                # build model
                model = build_model(cfg.model)

                # build helper and extract features
                extract_helper = build_extract_helper(model, cfg.extract)
                extract_helper.do_extract(dataloader, save_path=os.path.join(args.save_path, feature_full_name))


if __name__ == '__main__':
    main()
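
# Example invocation (save path and module name are illustrative):
#   python3 search/search_extract.py -sp /data/features/ -sm search_modules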