mirror of https://github.com/PyRetri/PyRetri.git
commit 06afa9b644 ("upload"), parent 402cb98b3b
@@ -67,7 +67,7 @@ Examples:
 # for dataset collecting images with the same label in one directory
 python3 main/make_data_json.py -d /data/caltech101/gallery/ -sp data_jsons/caltech_gallery.json -t general

-python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t feneral
+python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t general

 # for oxford/paris dataset
 python3 main/make_data_json.py -d /data/cbir/oxford/gallery/ -sp data_jsons/oxford_gallery.json -t oxford -gt /data/cbir/oxford/gt/
@@ -229,10 +229,10 @@ cd search/

 ### Define Search Space

-We decompose the search space into three sub search spaces: data_process, extract and index, each of which corresponds to a specified file. Search space is defined by adding methods with hyper-parameters to a specified dict. You can add a search operator as follows:
+We decompose the search space into three sub search spaces: pre_process, extract and index, each of which corresponds to a specified file. Search space is defined by adding methods with hyper-parameters to a specified dict. You can add a search operator as follows:

 ```shell
-data_processes.add(
+pre_processes.add(
     "PadResize224",
     {
         "batch_size": 32,

@@ -257,7 +257,7 @@ data_processes.add(
 )
 ```

-By doing this, a data_process operator named "PadResize224" is added to the data_process sub search space and will be searched in the following process.
+By doing this, a pre_process operator named "PadResize224" is added to the pre_process sub search space and will be searched in the following process.

 ### Search
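For context, the three sub search spaces are combined exhaustively by the search scripts shown further down in this commit. A minimal sketch of that enumeration, with made-up operator names standing in for real entries:

```python
# Each sub search space behaves like a dict of {operator_name: hyper_params}.
pre_processes = {"PadResize224": {"batch_size": 32}}
models = {"market_res50": {"name": "ft_net"}}
indexes = {"pca_whiten": {"proj_dim": 512}}

# The search scripts take the Cartesian product of the three spaces,
# evaluating one retrieval pipeline per combination.
for pre_name, pre_args in pre_processes.items():
    for model_name, model_args in models.items():
        for index_name, index_args in indexes.items():
            print(pre_name, model_name, index_name)
```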
@@ -287,7 +287,7 @@ python3 search_extract.py -sp /data/features/gap_gmp_gem_crow_spoc/ -sm search_m
 Search for the indexing combinations by:

 ```shell
-python3 search_query.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
+python3 search_index.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
 ```

 Arguments:
@@ -299,8 +299,23 @@ Arguments:
 Examples:

 ```shell
-python3 search_query.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
+python3 search_index.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
 ```

 #### show search results

 We provide two ways to show the search results. One is saving all the search results in a csv format file, which can be used for further analyses. The other is showing the search results according to the given keywords. You can define the keywords as follows:

 ```sh
 keywords = {
     'data_name': ['market'],
     'pre_process_name': list(),
     'model_name': list(),
     'feature_map_name': list(),
     'aggregator_name': list(),
     'post_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
 }
 ```

 See show_search_results.py for more details.
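As a usage sketch, assuming the save_to_csv and filter_by_keywords helpers that this commit adds to utils/misc.py (the file paths below are illustrative):

```python
import json
from utils.misc import save_to_csv, filter_by_keywords

with open("/data/features/gap_gmp_gem_crow_spoc_result.json", "r") as f:
    results = json.load(f)

save_to_csv(results, "all_results.csv")  # dump everything for offline analysis

# keep only the market results with the listed post-processing variants
keywords = {"data_name": ["market"],
            "post_process_name": ["pca_whiten", "pca_wo_whiten"]}
for r in filter_by_keywords(results, keywords):
    print(r)
```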
@@ -35,7 +35,7 @@ cd pyretri
 3. Install PyRetri.

 ```shell
-python setup.py install
+python3 setup.py install
 ```

 ## Prepare Datasets
@@ -35,7 +35,7 @@ Choosing the implementations mentioned above as baselines and adding some tricks

 For person re-identification, we use the model provided by [Person_reID_baseline](https://github.com/layumi/Person_reID_baseline_pytorch) and reproduce its results. In addition, we train a model on DukeMTMC-reID through the open source code for further experiments.

-###pre-trained models
+### pre-trained models

 | Training Set | Backbone | for Short | Download |
 | :-----------: | :-------: | :-------: | :------: |
@@ -84,7 +84,9 @@ class PartPCA(DimProcessorBase):
             pca = self.pcas[fea_name]["pca"]

             ori_fea = fea[:, st_idx: ed_idx]
-            proj_fea = pca.transform(ori_fea)
+            proj_fea = normalize(ori_fea, norm='l2')
+            proj_fea = pca.transform(proj_fea)
+            proj_fea = normalize(proj_fea, norm='l2')

             ret.append(proj_fea)
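The fix sandwiches the PCA projection between L2 normalizations. A self-contained sketch of the same pipeline with scikit-learn, on random stand-in features; whether the repo fits the PCA on normalized features is not shown in this hunk, so the sketch assumes it does:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

train_fea = np.random.rand(1000, 2048)  # stand-in training features
fea = np.random.rand(10, 2048)          # stand-in features to reduce

# fit on L2-normalized training features, mirroring the transform-time pipeline
pca = PCA(n_components=512, whiten=True).fit(normalize(train_fea, norm='l2'))

proj_fea = normalize(fea, norm='l2')       # L2-normalize before projecting
proj_fea = pca.transform(proj_fea)         # project to 512 dims
proj_fea = normalize(proj_fea, norm='l2')  # re-normalize after projection
```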
@@ -65,6 +65,7 @@ class PartSVD(DimProcessorBase):
             else:
                 proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
+                assert proj_part_dim < ori_part_dim, "reduction dimension can not be distributed to each part!"
             already_proj_dim += proj_part_dim

             svd = SKSVD(n_components=proj_part_dim)
             train_fea = fea[:, st_idx: ed_idx]
@@ -79,22 +80,23 @@ class PartSVD(DimProcessorBase):
         }

     def __call__(self, fea: np.ndarray) -> np.ndarray:
-        if self._hyper_params["proj_dim"] != 0:
-            ret = np.zeros(shape=(fea.shape[0], self._hyper_params["proj_dim"]))
-        else:
-            ret = np.zeros(shape=(fea.shape[0], fea.shape[1] - len(self.svds)))
+        fea_names = np.sort(list(self.svds.keys()))
+        ret = list()

-        for fea_name in self.svds:
+        for fea_name in fea_names:
             st_idx, ed_idx = self.svds[fea_name]["pos"][0], self.svds[fea_name]["pos"][1]
             svd = self.svds[fea_name]["svd"]

             proj_fea = fea[:, st_idx: ed_idx]
             proj_fea = normalize(proj_fea, norm='l2')
             proj_fea = svd.transform(proj_fea)
             if self._hyper_params["whiten"]:
                 proj_fea = proj_fea / (self.svds[fea_name]["std"] + 1e-6)
             proj_fea = normalize(proj_fea, norm='l2')

-            ret[:, st_idx: ed_idx] = proj_fea
+            ret.append(proj_fea)

+        ret = np.concatenate(ret, axis=1)
         return ret
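The rewrite matters because st_idx/ed_idx address columns in the original feature layout; once each part is projected to a smaller width, writing into a preallocated array at those offsets no longer lines up. Collecting the projected parts in a list and concatenating sidesteps the offset bookkeeping, as this toy illustration with hypothetical shapes shows:

```python
import numpy as np

# two feature parts, projected to different (reduced) widths
proj_parts = [np.ones((4, 3)), np.zeros((4, 2))]

ret = np.concatenate(proj_parts, axis=1)  # final matrix: 4 samples x 5 dims
assert ret.shape == (4, 5)
```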
@@ -49,24 +49,27 @@ def main():
     # load search space
     datasets = load_datasets()
-    data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
+    pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
     models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
+    extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts

     # search in an exhaustive way
     for data_name, data_args in datasets.items():
-        for data_proc_name, data_proc_args in data_processes.items():
+        for pre_proc_name, pre_proc_args in pre_processes.items():
             if 'market' in data_name:
                 model_name = 'market_res50'
             elif 'duke' in data_name:
                 model_name = 'duke_res50'

-            feature_full_name = data_name + "_" + data_proc_name + "_" + model_name
+            feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
             print(feature_full_name)

             if os.path.exists(os.path.join(args.save_path, feature_full_name)):
                 print("[Search Extract]: config exists...")
                 continue

             # load retrieval pipeline settings
-            cfg.datasets.merge_from_other_cfg(data_proc_args)
+            cfg.datasets.merge_from_other_cfg(pre_proc_args)
             cfg.model.merge_from_other_cfg(models[model_name])
+            cfg.extract.merge_from_other_cfg(extracts[model_name])
@@ -5,7 +5,7 @@ import importlib
 import os
 import argparse

-from .utils.misc import check_exist, get_dir
+from utils.misc import check_result_exist, get_dir, get_default_result_dict

 from pyretri.config import get_defaults_cfg
 from pyretri.index import build_index_helper, feature_loader
@@ -53,8 +53,8 @@ def main():
     # load search space
     datasets = load_datasets()
-    queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
-    evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
+    indexes = importlib.import_module("{}.index_dict".format(args.search_modules)).indexes
+    evaluates = importlib.import_module("{}.index_dict".format(args.search_modules)).evaluates

     if os.path.exists(args.save_path):
         with open(args.save_path, "r") as f:
@@ -64,8 +64,9 @@ def main():

     for dir in os.listdir(args.fea_dir):
         for data_name, data_args in datasets.items():
-            for query_name, query_args in queries.items():
+            for index_name, index_args in indexes.items():
                 if data_name in dir:
                     print(dir)

                     # get dirs
                     gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
@@ -73,21 +74,19 @@ def main():
                     # get evaluate setting
                     evaluate_args = evaluates["reid_overall"]

-                    for dim_proc in query_args.dim_processors.names:
+                    for dim_proc in index_args.dim_processors.names:
                         if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
-                            query_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir
+                            index_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir

                     for fea_name in fea_names:

-                        result_dict = get_default_result_dict(dir, data_name, query_name, fea_name)
-                        if check_exist(result_dict, results):
+                        result_dict = get_default_result_dict(dir, data_name, index_name, fea_name)
+                        if check_result_exist(result_dict, results):
                             print("[Search Query]: config exists...")
                             continue
-                        print(data_name + '_' + fea_name + '_' + query_name)

                         # load retrieval pipeline settings
-                        query_args.feature_names = [fea_name]
-                        cfg.index.merge_from_other_cfg(query_args)
+                        index_args.feature_names = [fea_name]
+                        cfg.index.merge_from_other_cfg(index_args)
                         cfg.evaluate.merge_from_other_cfg(evaluate_args)

                         # load features
@@ -95,12 +94,12 @@ def main():
                         gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])

                         # build helper and index features
-                        query_helper = build_index_helper(cfg.index)
-                        query_result_info, _, _ = query_helper.do_index(query_fea, query_info, gallery_fea)
+                        index_helper = build_index_helper(cfg.index)
+                        index_result_info, _, _ = index_helper.do_index(query_fea, query_info, gallery_fea)

                         # build helper and evaluate results
                         evaluate_helper = build_evaluate_helper(cfg.evaluate)
-                        mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
+                        mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)

                         # record results
                         to_save_recall = dict()
@@ -19,6 +19,7 @@ models.add(
 extracts.add(
     "market_res50",
     {
         "assemble": 1,
         "extractor": {
             "name": "ReIDSeries",
             "ReIDSeries": {
@@ -39,13 +40,14 @@ models.add(
     {
         "name": "ft_net",
         "ft_net": {
-            "load_checkpoint": "/home/songrenjie/projects/reID_baseline/model/ft_ResNet50/net_59.pth"
+            "load_checkpoint": "/home/songrenjie/projects/reID_baseline/model/ft_ResNet50/res50_duke.pth"
         }
     }
 )
 extracts.add(
     "duke_res50",
     {
         "assemble": 1,
         "extractor": {
             "name": "ReIDSeries",
             "ReIDSeries": {
@@ -3,10 +3,10 @@
 from utils.search_modules import SearchModules
 from pyretri.config import get_defaults_cfg

-queries = SearchModules()
+indexes = SearchModules()
 evaluates = SearchModules()

-queries.add(
+indexes.add(
     "no_fea_process",
     {
         "gallery_fea_dir": "",
@@ -32,7 +32,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "l2_normalize",
     {
         "gallery_fea_dir": "",
@@ -58,7 +58,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "pca_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -67,11 +67,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartPCA"],
-            "PartPCA": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },
@@ -89,7 +90,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "pca_whiten",
     {
         "gallery_fea_dir": "",
@@ -98,11 +99,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartPCA"],
-            "PartPCA": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },
@@ -120,7 +122,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "svd_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -129,11 +131,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartSVD"],
-            "PartSVD": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },
@@ -151,7 +154,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "svd_whiten",
     {
         "gallery_fea_dir": "",
@@ -160,11 +163,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartSVD"],
-            "PartSVD": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },
@@ -211,5 +215,5 @@ evaluates.add(

 cfg = get_defaults_cfg()

-queries.check_valid(cfg["index"])
+indexes.check_valid(cfg["index"])
 evaluates.check_valid(cfg["evaluate"])
@@ -3,9 +3,9 @@
 from utils.search_modules import SearchModules
 from pyretri.config import get_defaults_cfg

-data_processes = SearchModules()
+pre_processes = SearchModules()

-data_processes.add(
+pre_processes.add(
     "Direct256128",
     {
         "batch_size": 32,
@@ -31,4 +31,4 @@ data_processes.add(

 cfg = get_defaults_cfg()

-data_processes.check_valid(cfg["datasets"])
+pre_processes.check_valid(cfg["datasets"])
@@ -1,149 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import json
-import importlib
-import os
-import argparse
-
-from pyretri.config import get_defaults_cfg
-from pyretri.query import build_query_helper
-from pyretri.evaluate import build_evaluate_helper
-from pyretri.query.utils import feature_loader
-
-fea_names = ["output"]
-
-task_mapping = {
-    "market_gallery": {
-        "gallery": "market_gallery",
-        "query": "market_query",
-        "train_fea_dir": "market_gallery"
-    },
-    "duke_gallery": {
-        "gallery": "duke_gallery",
-        "query": "duke_query",
-        "train_fea_dir": "duke_gallery"
-    },
-}
-
-
-def check_exist(now_res, exist_results):
-    for e_r in exist_results:
-        totoal_equal = True
-        for key in now_res:
-            if now_res[key] != e_r[key]:
-                totoal_equal = False
-                break
-        if totoal_equal:
-            return True
-    return False
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
-    parser.add_argument('--fea_dir', '-f', default=None, type=str, help="path of feature directory", required=True)
-    parser.add_argument(
-        "--search_modules",
-        "-m",
-        default="",
-        help="name of search module's directory",
-        type=str,
-        required=True
-    )
-    parser.add_argument("--save_path", "-s", default=None, type=str, required=True)
-    args = parser.parse_args()
-
-    return args
-
-
-def get_default_result_dict(dir, task_name, query_name, fea_name):
-
-    result_dict = {
-        "task_name": task_name.split("_")[0],
-        "dataprocess": dir.split("_")[0],
-        "model_name": "_".join(dir.split("_")[-2:]),
-        "feature_map_name": fea_name.split("_")[0],
-        "fea_process_name": query_name
-    }
-    if fea_name == "output":
-        result_dict["aggregator_name"] = "none"
-    else:
-        result_dict["aggregator_name"] = fea_name.split("_")[1]
-
-    return result_dict
-
-
-def main():
-
-    args = parse_args()
-
-    cfg = get_defaults_cfg()
-
-    queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
-    evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
-
-    if os.path.exists(args.save_path):
-        with open(args.save_path, "r") as f:
-            results = json.load(f)
-    else:
-        results = list()
-
-    q_cnt = 0
-    for dir in os.listdir(args.fea_dir):
-        q_cnt += 1
-        print("Processing {} / {} queries...".format(q_cnt, len(queries)))
-        for task_name in task_mapping:
-            for query_name, query_args in queries.items():
-                if task_name in dir:
-                    for fea_name in fea_names:
-
-                        gallery_fea_dir = os.path.join(args.fea_dir, dir)
-                        query_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["query"])
-                        train_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["train_fea_dir"])
-
-                        assert os.path.exists(gallery_fea_dir), gallery_fea_dir
-                        assert os.path.exists(query_fea_dir), query_fea_dir
-                        assert os.path.exists(train_fea_dir), train_fea_dir
-
-                        for post_proc in ["PartPCA", "PartSVD"]:
-                            if post_proc in query_args.post_processors.names:
-                                query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
-                        query_args.gallery_fea_dir, query_args.query_fea_dir = gallery_fea_dir, query_fea_dir
-                        query_args.feature_names = [fea_name]
-                        eval_args = evaluates["reid_overall"]
-
-                        result_dict = get_default_result_dict(dir, task_name, query_name, fea_name)
-
-                        if check_exist(result_dict, results):
-                            print("[Search Query]: config exists...")
-                            continue
-
-                        cfg.query.merge_from_other_cfg(query_args)
-                        cfg.evaluate.merge_from_other_cfg(eval_args)
-
-                        query_helper = build_query_helper(cfg.query)
-                        evaluate_helper = build_evaluate_helper(cfg.evaluate)
-
-                        query_fea, query_info_dicts, _ = feature_loader.load(cfg.query.query_fea_dir,
-                                                                             cfg.query.feature_names)
-                        gallery_fea, gallery_info_dicts, _ = feature_loader.load(cfg.query.gallery_fea_dir,
-                                                                                 cfg.query.feature_names)
-
-                        query_result_info_dicts, _, _ = query_helper.do_query(query_fea, query_info_dicts, gallery_fea)
-                        mAP, recall_at_k = evaluate_helper.do_eval(query_result_info_dicts, gallery_info_dicts)
-
-                        to_save_dict = dict()
-                        for k in recall_at_k:
-                            to_save_dict[str(k)] = recall_at_k[k]
-
-                        result_dict["mAP"] = float(mAP)
-                        result_dict["recall_at_k"] = to_save_dict
-                        print(result_dict)
-                        assert False
-
-                        results.append(result_dict)
-                        with open(args.save_path, "w") as f:
-                            json.dump(results, f)
-
-
-if __name__ == '__main__':
-    main()
@@ -30,17 +30,11 @@ def load_datasets():
     return datasets

-def check_exist(save_path, full_name):
-    if os.path.exists(os.path.join(save_path, full_name)):
-        return True
-    return False

 def parse_args():
     parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
     parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
     parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for feature")
-    parser.add_argument("--search_modules", "-sm", default="", type=str, help="name of search module's directory")
+    parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
     args = parser.parse_args()

     return args
@@ -58,24 +52,24 @@ def main():
     # load search space
     datasets = load_datasets()
-    data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
+    pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
     models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
+    extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts

     # search in an exhaustive way
     for data_name, data_args in datasets.items():
-        for data_proc_name, data_proc_args in data_processes.items():
+        for pre_proc_name, pre_proc_args in pre_processes.items():
             for model_name, model_args in models.items():

-                feature_full_name = data_name + "_" + data_proc_name + "_" + model_name
+                feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
                 print(feature_full_name)

-                if check_exist(args.save_path, feature_full_name):
+                if os.path.exists(os.path.join(args.save_path, feature_full_name)):
                     print("[Search Extract]: config exists...")
                     continue

                 # load retrieval pipeline settings
-                cfg.datasets.merge_from_other_cfg(data_proc_args)
+                cfg.datasets.merge_from_other_cfg(pre_proc_args)
                 cfg.model.merge_from_other_cfg(model_args)
+                cfg.extract.merge_from_other_cfg(extracts[model_name])
|
@ -5,13 +5,26 @@ import importlib
|
|||
import os
|
||||
import argparse
|
||||
|
||||
from .utils.misc import check_exist, get_dir
|
||||
from utils.misc import check_result_exist, get_dir, get_default_result_dict
|
||||
|
||||
from pyretri.config import get_defaults_cfg
|
||||
from pyretri.index import build_index_helper, feature_loader
|
||||
from pyretri.evaluate import build_evaluate_helper
|
||||
|
||||
|
||||
# # gap, gmp, gem, spoc, crow
|
||||
# vgg_fea = ["pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
|
||||
# "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow",
|
||||
# "fc"]
|
||||
# res_fea = ["pool3_GAP", "pool3_GMP", "pool3_GeM", "pool3_SPoC", "pool4_Crow",
|
||||
# "pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
|
||||
# "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow"]
|
||||
|
||||
# # scda, rmca
|
||||
# vgg_fea = ["pool5_SCDA", "pool5_RMAC"]
|
||||
# res_fea = ["pool5_SCDA", "pool5_RMAC"]
|
||||
|
||||
# pwa
|
||||
vgg_fea = ["pool5_PWA"]
|
||||
res_fea = ["pool5_PWA"]
|
||||
|
||||
|
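The feature names follow a `<feature_map>_<aggregator>` convention (pool5_GAP, pool5_PWA, ...), while bare names like fc carry no aggregator; get_default_result_dict later in this commit splits on the underscore accordingly. A hypothetical helper illustrating that parsing:

```python
def split_fea_name(fea_name: str):
    # "pool5_PWA" -> ("pool5", "PWA"); bare names like "fc" -> ("fc", "none")
    parts = fea_name.split("_")
    return (parts[0], parts[1]) if len(parts) > 1 else (parts[0], "none")

assert split_fea_name("pool5_PWA") == ("pool5", "PWA")
assert split_fea_name("fc") == ("fc", "none")
```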
@@ -80,8 +93,8 @@ def main():
     # load search space
     datasets = load_datasets()
-    queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
-    evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
+    indexes = importlib.import_module("{}.index_dict".format(args.search_modules)).indexes
+    evaluates = importlib.import_module("{}.index_dict".format(args.search_modules)).evaluates

     if os.path.exists(args.save_path):
         with open(args.save_path, "r") as f:
@@ -91,8 +104,10 @@ def main():

     for dir in os.listdir(args.fea_dir):
         for data_name, data_args in datasets.items():
-            for query_name, query_args in queries.items():
+            for index_name, index_args in indexes.items():
                 if data_name in dir:
                     print(dir)

                     # get dirs
                     gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
@@ -102,19 +117,20 @@ def main():
                     # get feature names
                     fea_names = get_fea_names(gallery_fea_dir)

-                    for post_proc in query_args.post_processors.names:
-                        if post_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
-                            query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
+                    # set train feature path for dimension reduction processes
+                    for dim_proc in index_args.dim_processors.names:
+                        if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
+                            index_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir

                     for fea_name in fea_names:
-                        result_dict = get_default_result_dict(dir, data_name, query_name, fea_name)
-                        if check_exist(result_dict, results):
+                        result_dict = get_default_result_dict(dir, data_name, index_name, fea_name)
+                        if check_result_exist(result_dict, results):
                             print("[Search Query]: config exists...")
                             continue

                         # load retrieval pipeline settings
-                        query_args.feature_names = [fea_name]
-                        cfg.index.merge_from_other_cfg(query_args)
+                        index_args.feature_names = [fea_name]
+                        cfg.index.merge_from_other_cfg(index_args)
                         cfg.evaluate.merge_from_other_cfg(evaluate_args)

                         # load features
@@ -122,12 +138,12 @@ def main():
                         gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])

                         # build helper and index features
-                        query_helper = build_query_helper(cfg.query)
-                        query_result_info, _, _ = query_helper.do_query(query_fea, query_info, gallery_fea)
+                        index_helper = build_index_helper(cfg.index)
+                        index_result_info, _, _ = index_helper.do_index(query_fea, query_info, gallery_fea)

                         # build helper and evaluate results
                         evaluate_helper = build_evaluate_helper(cfg.evaluate)
-                        mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
+                        mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)

                         # record results
                         to_save_recall = dict()
@@ -3,10 +3,10 @@
 from utils.search_modules import SearchModules
 from pyretri.config import get_defaults_cfg

-queries = SearchModules()
+indexes = SearchModules()
 evaluates = SearchModules()

-queries.add(
+indexes.add(
     "pca_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -15,11 +15,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartPCA"],
-            "PartPCA": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },
@@ -37,7 +38,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "pca_whiten",
     {
         "gallery_fea_dir": "",
@@ -46,11 +47,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartPCA"],
-            "PartPCA": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },
@@ -68,7 +70,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "svd_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -77,11 +79,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartSVD"],
-            "PartSVD": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },
@@ -99,7 +102,7 @@ queries.add(
     }
 )

-queries.add(
+indexes.add(
     "svd_whiten",
     {
         "gallery_fea_dir": "",
@@ -108,11 +111,12 @@ queries.add(
         "feature_names": [],

         "dim_processors": {
-            "names": ["PartSVD"],
-            "PartSVD": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },
@@ -150,5 +154,5 @@ evaluates.add(

 cfg = get_defaults_cfg()

-queries.check_valid(cfg["index"])
+indexes.check_valid(cfg["index"])
 evaluates.check_valid(cfg["evaluate"])
@@ -3,9 +3,9 @@
 from utils.search_modules import SearchModules
 from pyretri.config import get_defaults_cfg

-data_processes = SearchModules()
+pre_processes = SearchModules()

-data_processes.add(
+pre_processes.add(
     "Shorter256Center224",
     {
         "batch_size": 32,
@@ -16,8 +16,8 @@ data_processes.add(
             "name": "CollateFn"
         },
         "transformers": {
-            "names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
-            "ResizeShorter": {
+            "names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
+            "ShorterResize": {
                 "size": 256
             },
             "CenterCrop": {
@@ -31,7 +31,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "Direct224",
     {
         "batch_size": 32,
@@ -54,7 +54,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "PadResize224",
     {
         "batch_size": 32,
@@ -80,4 +80,4 @@ data_processes.add(

 cfg = get_defaults_cfg()

-data_processes.check_valid(cfg["datasets"])
+pre_processes.check_valid(cfg["datasets"])
@@ -21,9 +21,9 @@ def load_datasets():
         "cub_query": os.path.join(data_json_dir, "cub_query.json"),
         "indoor_gallery": os.path.join(data_json_dir, "indoor_gallery.json"),
         "indoor_query": os.path.join(data_json_dir, "indoor_query.json"),
-        "caltech101_gallery": os.path.join(data_json_dir, "caltech101_gallery.json"),
-        "caltech101_query": os.path.join(data_json_dir, "caltech101_query.json"),
-        "paris": os.path.join(data_json_dir, "paris.json"),
+        "caltech_gallery": os.path.join(data_json_dir, "caltech_gallery.json"),
+        "caltech_query": os.path.join(data_json_dir, "caltech_query.json"),
+        "paris_all": os.path.join(data_json_dir, "paris.json"),
     }
     for data_path in datasets.values():
         assert os.path.exists(data_path), "non-exist dataset path {}".format(data_path)
@@ -48,25 +48,30 @@ def main():
     # init retrieval pipeline settings
     cfg = get_defaults_cfg()

-    data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
+    # load search space
+    datasets = load_datasets()
+    pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
     models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
+    extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts

-    datasets = load_datasets()

     for data_name, data_args in datasets.items():
-        for data_proc_name, data_proc_args in data_processes.items():
+        for pre_proc_name, pre_proc_args in pre_processes.items():
             for model_name, model_args in models.items():

-                feature_full_name = data_process_name + "_" + dataset_name + "_" + model_name
+                feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
                 print(feature_full_name)

                 if os.path.exists(os.path.join(args.save_path, feature_full_name)):
                     print("[Search Extract]: config exists...")
                     continue

                 # load retrieval pipeline settings
-                cfg.datasets.merge_from_other_cfg(data_proc_args)
+                cfg.datasets.merge_from_other_cfg(pre_proc_args)
                 cfg.model.merge_from_other_cfg(model_args)
+                cfg.extract.merge_from_other_cfg(extracts[model_name])

-                pwa_train_fea_dir = os.path.join("/data/my_features/gap_gmp_gem_crow_spoc", feature_full_name)
+                # set train feature path for pwa
+                pwa_train_fea_dir = os.path.join("/data/features/test_gap_gmp_gem_crow_spoc", feature_full_name)
                 if "query" in pwa_train_fea_dir:
                     pwa_train_fea_dir.replace("query", "gallery")
                 elif "paris" in pwa_train_fea_dir:
|
@ -3,10 +3,10 @@
|
|||
from utils.search_modules import SearchModules
|
||||
from pyretri.config import get_defaults_cfg
|
||||
|
||||
queries = SearchModules()
|
||||
indexes = SearchModules()
|
||||
evaluates = SearchModules()
|
||||
|
||||
queries.add(
|
||||
indexes.add(
|
||||
"pca_wo_whiten",
|
||||
{
|
||||
"gallery_fea_dir": "",
|
||||
|
@@ -14,16 +14,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartPCA",
-            "PartPCA": {
+        "dim_processors": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -31,13 +32,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "pca_whiten",
     {
         "gallery_fea_dir": "",
@@ -45,16 +46,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartPCA",
-            "PartPCA": {
+        "dim_processors": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -62,13 +64,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "svd_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -76,16 +78,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartSVD",
-            "PartSVD": {
+        "dim_processors": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -93,13 +96,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "svd_whiten",
     {
         "gallery_fea_dir": "",
@@ -107,16 +110,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartSVD",
-            "PartSVD": {
+        "dim_processors": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -124,7 +128,7 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
@@ -150,5 +154,5 @@ evaluates.add(

 cfg = get_defaults_cfg()

-queries.check_valid(cfg["query"])
+indexes.check_valid(cfg["index"])
 evaluates.check_valid(cfg["evaluate"])
|
@ -3,9 +3,9 @@
|
|||
from utils.search_modules import SearchModules
|
||||
from pyretri.config import get_defaults_cfg
|
||||
|
||||
data_processes = SearchModules()
|
||||
pre_processes = SearchModules()
|
||||
|
||||
data_processes.add(
|
||||
pre_processes.add(
|
||||
"Shorter256Center224",
|
||||
{
|
||||
"batch_size": 32,
|
||||
|
@ -16,8 +16,8 @@ data_processes.add(
|
|||
"name": "CollateFn"
|
||||
},
|
||||
"transformers": {
|
||||
"names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
|
||||
"ResizeShorter": {
|
||||
"names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
|
||||
"ShorterResize": {
|
||||
"size": 256
|
||||
},
|
||||
"CenterCrop": {
|
||||
|
@@ -31,7 +31,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "Direct224",
     {
         "batch_size": 32,
@@ -54,7 +54,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "PadResize224",
     {
         "batch_size": 32,
@@ -80,4 +80,4 @@ data_processes.add(

 cfg = get_defaults_cfg()

-data_processes.check_valid(cfg["datasets"])
+pre_processes.check_valid(cfg["datasets"])
@@ -1,167 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import json
-import importlib
-import os
-import argparse
-
-from pyretri.config import get_defaults_cfg
-from pyretri.query import build_query_helper
-from pyretri.evaluate import build_evaluate_helper
-from pyretri.index import feature_loader
-
-
-vgg_fea = ["pool5_PWA"]
-res_fea = ["pool5_PWA"]
-
-task_mapping = {
-    "oxford_gallery": {
-        "gallery": "oxford_gallery",
-        "query": "oxford_query",
-        "train_fea_dir": "paris"
-    },
-    "cub_gallery": {
-        "gallery": "cub_gallery",
-        "query": "cub_query",
-        "train_fea_dir": "cub_gallery"
-    },
-    "indoor_gallery": {
-        "gallery": "indoor_gallery",
-        "query": "indoor_query",
-        "train_fea_dir": "indoor_gallery"
-    },
-    "caltech101_gallery": {
-        "gallery": "caltech101_gallery",
-        "query": "caltech101_query",
-        "train_fea_dir": "caltech101_gallery"
-    }
-}
-
-
-def check_exist(now_res, exist_results):
-    for e_r in exist_results:
-        totoal_equal = True
-        for key in now_res:
-            if now_res[key] != e_r[key]:
-                totoal_equal = False
-                break
-        if totoal_equal:
-            return True
-    return False
-
-
-def get_default_result_dict(dir, task_name, query_name, fea_name):
-    result_dict = {
-        "task_name": task_name.split("_")[0],
-        "dataprocess": dir.split("_")[0],
-        "model_name": "_".join(dir.split("_")[-2:]),
-        "feature_map_name": fea_name.split("_")[0],
-        "fea_process_name": query_name
-    }
-
-    if fea_name == "fc":
-        result_dict["aggregator_name"] = "none"
-    else:
-        result_dict["aggregator_name"] = fea_name.split("_")[1]
-
-    return result_dict
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
-    parser.add_argument('--fea_dir', '-fd', default=None, type=str, help="path of feature dirs", required=True)
-    parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
-    parser.add_argument("--save_path", "-sp", default=None, type=str, help="path for saving results")
-    args = parser.parse_args()
-
-    return args
-
-
-def main():
-
-    # init args
-    args = parse_args()
-    assert args.fea_dir is not None, 'the feature directory must be provided!'
-    assert args.search_modules is not None, 'the search modules must be provided!'
-    assert args.save_path is not None, 'the save path must be provided!'
-
-    # init retrieval pipeline settings
-    cfg = get_defaults_cfg()
-
-    # load search space
-    queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
-    evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
-
-    if os.path.exists(args.save_path):
-        with open(args.save_path, "r") as f:
-            results = json.load(f)
-    else:
-        results = list()
-
-    q_cnt = 0
-    for dir in os.listdir(args.fea_dir):
-        q_cnt += 1
-        print("Processing {} / {} queries...".format(q_cnt, len(queries)))
-        for query_name, query_args in queries.items():
-            for task_name in task_mapping:
-                if task_name in dir:
-
-                    if "vgg" in gallery_fea_dir:
-                        fea_names = vgg_fea
-                    else:
-                        fea_names = res_fea
-
-                    for fea_name in fea_names:
-                        gallery_fea_dir = os.path.join(args.fea_dir, dir)
-                        query_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["query"])
-                        train_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["train_fea_dir"])
-
-                        for post_proc in ["PartPCA", "PartSVD"]:
-                            if post_proc in query_args.post_processors.names:
-                                query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
-
-                        query.gallery_fea_dir, query.query_fea_dir = gallery_fea_dir, query_fea_dir
-
-                        query.feature_names = [fea_name]
-                        if task_name == "oxford_base":
-                            evaluate = evaluates["oxford_overall"]
-                        else:
-                            evaluate = evaluates["overall"]
-
-                        result_dict = get_default_result_dict(dir, task_name, query_name, fea_name)
-
-                        if check_exist(result_dict, results):
-                            print("[Search Query]: config exists...")
-                            continue
-
-                        # load retrieval pipeline settings
-                        cfg.query.merge_from_other_cfg(query)
-                        cfg.evaluate.merge_from_other_cfg(evaluate)
-
-                        # load features
-                        query_fea, query_info, _ = feature_loader.load(cfg.query.query_fea_dir, cfg.query.feature_names)
-                        gallery_fea, gallery_info, _ = feature_loader.load(cfg.query.gallery_fea_dir,
-                                                                           cfg.query.feature_names)
-
-                        # build helper and index features
-                        query_helper = build_query_helper(cfg.query)
-                        query_result_info, _, _ = query_helper.do_query(query_fea, query_info, gallery_fea)
-
-                        # build helper and evaluate results
-                        evaluate_helper = build_evaluate_helper(cfg.evaluate)
-                        mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
-
-                        # save results
-                        to_save_dict = dict()
-                        for k in recall_at_k:
-                            to_save_dict[str(k)] = recall_at_k[k]
-                        result_dict["mAP"] = float(mAP)
-                        result_dict["recall_at_k"] = to_save_dict
-
-                        results.append(result_dict)
-                        with open(args.save_path, "w") as f:
-                            json.dump(results, f)
-
-
-if __name__ == '__main__':
-    main()
|
@ -3,10 +3,10 @@
|
|||
from utils.search_modules import SearchModules
|
||||
from pyretri.config import get_defaults_cfg
|
||||
|
||||
queries = SearchModules()
|
||||
indexes = SearchModules()
|
||||
evaluates = SearchModules()
|
||||
|
||||
queries.add(
|
||||
indexes.add(
|
||||
"pca_wo_whiten",
|
||||
{
|
||||
"gallery_fea_dir": "",
|
||||
|
@@ -14,16 +14,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartPCA",
-            "PartPCA": {
+        "dim_processors": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -31,13 +32,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "pca_whiten",
     {
         "gallery_fea_dir": "",
@@ -45,16 +46,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartPCA",
-            "PartPCA": {
+        "dim_processors": {
+            "names": ["L2Normalize", "PCA", "L2Normalize"],
+            "PCA": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 512
+                "proj_dim": 512,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -62,13 +64,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "svd_wo_whiten",
     {
         "gallery_fea_dir": "",
@@ -76,16 +78,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartSVD",
-            "PartSVD": {
+        "dim_processors": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": False,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -93,13 +96,13 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
 )

-queries.add(
+indexes.add(
     "svd_whiten",
     {
         "gallery_fea_dir": "",
@@ -107,16 +110,17 @@ queries.add(
         "feature_names": [],

-        "post_processor": {
-            "name": "PartSVD",
-            "PartSVD": {
+        "dim_processors": {
+            "names": ["L2Normalize", "SVD", "L2Normalize"],
+            "SVD": {
                 "whiten": True,
                 "train_fea_dir": "",
-                "proj_dim": 511
+                "proj_dim": 511,
+                "l2": True,
             }
         },

         "database_enhance": {
             "feature_enhancer": {
                 "name": "Identity"
             },
@@ -124,7 +128,7 @@ queries.add(
             "name": "KNN"
         },

         "re_rank": {
             "re_ranker": {
                 "name": "Identity"
             }
         }
@@ -150,5 +154,5 @@ evaluates.add(

 cfg = get_defaults_cfg()

-queries.check_valid(cfg["query"])
+indexes.check_valid(cfg["index"])
 evaluates.check_valid(cfg["evaluate"])
|
@ -3,9 +3,9 @@
|
|||
from utils.search_modules import SearchModules
|
||||
from pyretri.config import get_defaults_cfg
|
||||
|
||||
data_processes = SearchModules()
|
||||
pre_processes = SearchModules()
|
||||
|
||||
data_processes.add(
|
||||
pre_processes.add(
|
||||
"Shorter256Center224",
|
||||
{
|
||||
"batch_size": 32,
|
||||
|
@ -16,8 +16,8 @@ data_processes.add(
|
|||
"name": "CollateFn"
|
||||
},
|
||||
"transformers": {
|
||||
"names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
|
||||
"ResizeShorter": {
|
||||
"names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
|
||||
"ShorterResize": {
|
||||
"size": 256
|
||||
},
|
||||
"CenterCrop": {
|
||||
|
@@ -31,7 +31,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "Direct224",
     {
         "batch_size": 32,
@@ -54,7 +54,7 @@ data_processes.add(
     }
 )

-data_processes.add(
+pre_processes.add(
     "PadResize224",
     {
         "batch_size": 32,
@@ -80,4 +80,4 @@ data_processes.add(

 cfg = get_defaults_cfg()

-data_processes.check_valid(cfg["datasets"])
+pre_processes.check_valid(cfg["datasets"])
@@ -4,9 +4,10 @@ import os
 import argparse
 import json

-import csv
-import codecs
+from utils.misc import save_to_csv, filter_by_keywords


 def parse_args():
     parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
@@ -22,62 +23,30 @@ def show_results(results):
         print(results[i])


-def save_to_csv(results, csv_path):
-    start = []
-    col_num = 12
-    if not os.path.exists(csv_path):
-        start = ["data_process", "model", "feature", "fea_process", "market_mAP", "market_mAP_re", "market_R1",
-                 "market_R1_re", "duke_mAP", "duke_mAP_re", "duke_R1", "duke_R1_re"]
-    with open(csv_path, 'a+') as f:
-        csv_write = csv.writer(f)
-        if len(start) > 0:
-            csv_write.writerow(start)
-        for i in range(len(results)):
-            data_row = [0 for x in range(col_num)]
-            data_row[0] = results[i]["dataprocess"]
-            data_row[1] = results[i]["model_name"]
-            data_row[2] = results[i]["feature_map_name"]
-            data_row[3] = results[i]["fea_process_name"]
-            if results[i]["task_name"] == 'market':
-                data_row[4] = results[i]["mAP"]
-                data_row[6] = results[i]["recall_at_k"]['1']
-            elif results[i]["task_name"] == 'duke':
-                data_row[8] = results[i]["mAP"]
-                data_row[10] = results[i]["recall_at_k"]['1']
-            csv_write.writerow(data_row)


 def main():

     # init args
     args = parse_args()
     assert os.path.exists(args.results_json_path), 'the config file must be existed!'

     with open(args.results_json_path, "r") as f:
         results = json.load(f)

-    key_words = {
-        'task_name': ['market'],
-        'dataprocess': list(),
-        'model_name': list(),
-        'feature_map_name': list(),
-        'aggregator_name': list(),
-        'fea_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
-    }

     # save the search results in a csv format file.
     csv_path = '/home/songrenjie/projects/RetrievalToolBox/test.csv'
     save_to_csv(results, csv_path)

-    for key in key_words:
-        no_match = []
-        if len(key_words[key]) == 0:
-            continue
-        else:
-            for i in range(len(results)):
-                if not results[i][key] in key_words[key]:
-                    no_match.append(i)
-        for num in no_match[::-1]:
-            results.pop(num)
+    # define the keywords to be selected
+    keywords = {
+        'data_name': ['market'],
+        'pre_process_name': list(),
+        'model_name': list(),
+        'feature_map_name': list(),
+        'aggregator_name': list(),
+        'post_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
+    }

+    # show search results according to the given keywords
+    results = filter_by_keywords(results, keywords)
     show_results(results)
|
@ -1,9 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from typing import Dict, List
|
||||
import csv
|
||||
|
||||
|
||||
def check_exist(now_res: Dict, exist_results: List) -> bool:
|
||||
def check_result_exist(now_res: Dict, exist_results: List) -> bool:
|
||||
"""
|
||||
Check if the config exists.
|
||||
|
||||
|
@@ -45,14 +46,14 @@ def get_dir(root_path: str, dir: str, dataset: Dict) -> (str, str, str):
     return gallery_fea_dir, query_fea_dir, train_fea_dir


-def get_default_result_dict(dir, data_name, query_name, fea_name) -> Dict:
+def get_default_result_dict(dir: str, data_name: str, index_name: str, fea_name: str) -> Dict:
     """
     Get the default result dict based on the experimental factors.

     Args:
         dir (str): the path of one single extracted feature directory.
         data_name (str): the name of the dataset.
-        query_name (str): the name of query process.
+        index_name (str): the name of the index process.
         fea_name (str): the name of the features to be loaded.

     Returns:
@@ -60,16 +61,70 @@ def get_default_result_dict(dir, data_name, query_name, fea_name) -> Dict:
     """
     result_dict = {
         "data_name": data_name.split("_")[0],
-        "dataprocess": dir.split("_")[0],
+        "pre_process_name": dir.split("_")[2],
         "model_name": "_".join(dir.split("_")[-2:]),
         "feature_map_name": fea_name.split("_")[0],
-        "fea_process_name": query_name
+        "post_process_name": index_name
     }

-    if fea_name == "fc":
+    if len(fea_name.split("_")) == 1:
         result_dict["aggregator_name"] = "none"
     else:
         result_dict["aggregator_name"] = fea_name.split("_")[1]

     return result_dict


+def save_to_csv(results: List[Dict], csv_path: str) -> None:
+    """
+    Save the search results in a csv format file.
+
+    Args:
+        results (List): a list of retrieval results.
+        csv_path (str): the path for saving the csv file.
+    """
+    start = ["data", "pre_process", "model", "feature_map", "aggregator", "post_process"]
+    for i in range(len(start)):
+        results = sorted(results, key=lambda result: result[start[len(start) - i - 1] + "_name"])
+    start.append('mAP')
+    start.append('Recall@1')
+
+    with open(csv_path, 'w') as f:
+        csv_write = csv.writer(f)
+        if len(start) > 0:
+            csv_write.writerow(start)
+        for i in range(len(results)):
+            data_row = [0 for x in range(len(start))]
+            data_row[0] = results[i]["data_name"]
+            data_row[1] = results[i]["pre_process_name"]
+            data_row[2] = results[i]["model_name"]
+            data_row[3] = results[i]["feature_map_name"]
+            data_row[4] = results[i]["aggregator_name"]
+            data_row[5] = results[i]["post_process_name"]
+            data_row[6] = results[i]["mAP"]
+            data_row[7] = results[i]["recall_at_k"]['1']
+            csv_write.writerow(data_row)
+
+
+def filter_by_keywords(results: List[Dict], keywords: Dict) -> List[Dict]:
+    """
+    Filter the search results according to the given keywords.
+
+    Args:
+        results (List): a list of retrieval results.
+        keywords (Dict): a dict containing keywords to be selected.
+
+    Returns:
+
+    """
+    for key in keywords:
+        no_match = []
+        if len(keywords[key]) == 0:
+            continue
+        else:
+            for i in range(len(results)):
+                if not results[i][key] in keywords[key]:
+                    no_match.append(i)
+        for num in no_match[::-1]:
+            results.pop(num)
+    return results
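A short usage sketch of check_result_exist as the search scripts use it to resume an interrupted sweep; the dicts here are illustrative:

```python
from utils.misc import check_result_exist

results = [{"data_name": "market", "post_process_name": "pca_whiten"}]  # loaded from save_path
now_res = {"data_name": "market", "post_process_name": "pca_whiten"}    # candidate config

# every key of now_res must match an existing entry for it to count as done
if check_result_exist(now_res, results):
    print("[Search Query]: config exists...")  # skip this combination
```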