2020-04-02 14:00:49 +08:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2020-04-15 14:44:22 +08:00
|
|
|
from pyretri.config import get_defaults_cfg, setup_cfg
|
|
|
|
from pyretri.datasets import build_folder, build_loader
|
|
|
|
from pyretri.models import build_model
|
|
|
|
from pyretri.extract import build_extract_helper
|
2020-04-02 14:00:49 +08:00
|
|
|
|
|
|
|
from torchvision import models
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the feature-extraction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            zero-argument call ``parse_args()`` behaves as before.

    Returns:
        argparse.Namespace: parsed options with the attributes ``opts``,
        ``data_json``, ``save_path``, ``config_file`` and ``save_interval``.
    """
    parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
    # Everything left over after the named options is collected verbatim
    # into ``opts``.
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
    parser.add_argument('--data_json', '-dj', default=None, type=str, help='json file for dataset to be extracted')
    parser.add_argument('--save_path', '-sp', default=None, type=str, help='save path for features')
    parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
    parser.add_argument('--save_interval', '-si', default=5000, type=int, help='number of features saved in one part file')
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Extract features for a dataset using the settings in a config file.

    Parses the command line, loads the configuration, builds the dataset,
    dataloader and model, and then hands them to the extract helper, which
    is given the save path and save interval.

    Raises:
        ValueError: if ``--data_json``, ``--save_path`` or ``--config_file``
            was not supplied on the command line.
    """
    # init args
    args = parse_args()

    # Validate with explicit raises rather than ``assert``: assert statements
    # are removed when Python runs with -O, which would let None slip through.
    required_options = (
        (args.data_json, 'the dataset json must be provided!'),
        (args.save_path, 'the save path must be provided!'),
        (args.config_file, 'a config file must be provided!'),
    )
    for value, message in required_options:
        if value is None:
            raise ValueError(message)

    # init and load retrieval pipeline settings
    cfg = get_defaults_cfg()
    cfg = setup_cfg(cfg, args.config_file, args.opts)

    # build dataset and dataloader
    dataset = build_folder(args.data_json, cfg.datasets)
    dataloader = build_loader(dataset, cfg.datasets)

    # build model
    model = build_model(cfg.model)

    # build helper and extract features
    extract_helper = build_extract_helper(model, cfg.extract)
    extract_helper.do_extract(dataloader, args.save_path, args.save_interval)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the extraction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|