mirror of https://github.com/PyRetri/PyRetri.git
upload
parent
402cb98b3b
commit
06afa9b644
|
@ -67,7 +67,7 @@ Examples:
|
||||||
# for dataset collecting images with the same label in one directory
|
# for dataset collecting images with the same label in one directory
|
||||||
python3 main/make_data_json.py -d /data/caltech101/gallery/ -sp data_jsons/caltech_gallery.json -t general
|
python3 main/make_data_json.py -d /data/caltech101/gallery/ -sp data_jsons/caltech_gallery.json -t general
|
||||||
|
|
||||||
python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t feneral
|
python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t general
|
||||||
|
|
||||||
# for oxford/paris dataset
|
# for oxford/paris dataset
|
||||||
python3 main/make_data_json.py -d /data/cbir/oxford/gallery/ -sp data_jsons/oxford_gallery.json -t oxford -gt /data/cbir/oxford/gt/
|
python3 main/make_data_json.py -d /data/cbir/oxford/gallery/ -sp data_jsons/oxford_gallery.json -t oxford -gt /data/cbir/oxford/gt/
|
||||||
|
@ -229,10 +229,10 @@ cd search/
|
||||||
|
|
||||||
### Define Search Space
|
### Define Search Space
|
||||||
|
|
||||||
We decompose the search space into three sub search spaces: data_process, extract and index, each of which corresponds to a specified file. Search space is defined by adding methods with hyper-parameters to a specified dict. You can add a search operator as follows:
|
We decompose the search space into three sub search spaces: pre_process, extract and index, each of which corresponds to a specified file. Search space is defined by adding methods with hyper-parameters to a specified dict. You can add a search operator as follows:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"PadResize224",
|
"PadResize224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -257,7 +257,7 @@ data_processes.add(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
By doing this, a data_process operator named "PadResize224" is added to the data_process sub search space and will be searched in the following process.
|
By doing this, a pre_process operator named "PadResize224" is added to the data_process sub search space and will be searched in the following process.
|
||||||
|
|
||||||
### Search
|
### Search
|
||||||
|
|
||||||
|
@ -287,7 +287,7 @@ python3 search_extract.py -sp /data/features/gap_gmp_gem_crow_spoc/ -sm search_m
|
||||||
Search for the indexing combinations by:
|
Search for the indexing combinations by:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python3 search_query.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
|
python3 search_index.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -299,8 +299,23 @@ Arguments:
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python3 search_query.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
|
python3 search_index.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### show search results
|
||||||
|
|
||||||
|
We provide two ways to show the search results. One is save all the search results in a csv format file, which can be used for further analyses. The other is showing the search results according to the given keywords. You can define the keywords as follows:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
keywords = {
|
||||||
|
'data_name': ['market'],
|
||||||
|
'pre_process_name': list(),
|
||||||
|
'model_name': list(),
|
||||||
|
'feature_map_name': list(),
|
||||||
|
'aggregator_name': list(),
|
||||||
|
'post_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
See show_search_results.py for more details.
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ cd pyretri
|
||||||
3. Install PyRetri.
|
3. Install PyRetri.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python setup.py install
|
python3 setup.py install
|
||||||
```
|
```
|
||||||
|
|
||||||
## Prepare Datasets
|
## Prepare Datasets
|
||||||
|
|
|
@ -35,7 +35,7 @@ Choosing the implementations mentioned above as baselines and adding some tricks
|
||||||
|
|
||||||
For person re-identification, we use the model provided by [Person_reID_baseline](https://github.com/layumi/Person_reID_baseline_pytorch) and reproduce its resutls. In addition, we train a model on DukeMTMC-reID through the open source code for further experiments.
|
For person re-identification, we use the model provided by [Person_reID_baseline](https://github.com/layumi/Person_reID_baseline_pytorch) and reproduce its resutls. In addition, we train a model on DukeMTMC-reID through the open source code for further experiments.
|
||||||
|
|
||||||
###pre-trained models
|
### pre-trained models
|
||||||
|
|
||||||
| Training Set | Backbone | for Short | Download |
|
| Training Set | Backbone | for Short | Download |
|
||||||
| :-----------: | :-------: | :-------: | :------: |
|
| :-----------: | :-------: | :-------: | :------: |
|
||||||
|
|
|
@ -84,7 +84,9 @@ class PartPCA(DimProcessorBase):
|
||||||
pca = self.pcas[fea_name]["pca"]
|
pca = self.pcas[fea_name]["pca"]
|
||||||
|
|
||||||
ori_fea = fea[:, st_idx: ed_idx]
|
ori_fea = fea[:, st_idx: ed_idx]
|
||||||
proj_fea = pca.transform(ori_fea)
|
proj_fea = normalize(ori_fea, norm='l2')
|
||||||
|
proj_fea = pca.transform(proj_fea)
|
||||||
|
proj_fea = normalize(proj_fea, norm='l2')
|
||||||
|
|
||||||
ret.append(proj_fea)
|
ret.append(proj_fea)
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,7 @@ class PartSVD(DimProcessorBase):
|
||||||
else:
|
else:
|
||||||
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
|
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
|
||||||
assert proj_part_dim < ori_part_dim, "reduction dimension can not be distributed to each part!"
|
assert proj_part_dim < ori_part_dim, "reduction dimension can not be distributed to each part!"
|
||||||
|
already_proj_dim += proj_part_dim
|
||||||
|
|
||||||
svd = SKSVD(n_components=proj_part_dim)
|
svd = SKSVD(n_components=proj_part_dim)
|
||||||
train_fea = fea[:, st_idx: ed_idx]
|
train_fea = fea[:, st_idx: ed_idx]
|
||||||
|
@ -79,22 +80,23 @@ class PartSVD(DimProcessorBase):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||||
if self._hyper_params["proj_dim"] != 0:
|
fea_names = np.sort(list(self.svds.keys()))
|
||||||
ret = np.zeros(shape=(fea.shape[0], self._hyper_params["proj_dim"]))
|
ret = list()
|
||||||
else:
|
|
||||||
ret = np.zeros(shape=(fea.shape[0], fea.shape[1] - len(self.svds)))
|
|
||||||
|
|
||||||
for fea_name in self.svds:
|
for fea_name in fea_names:
|
||||||
st_idx, ed_idx = self.svds[fea_name]["pos"][0], self.svds[fea_name]["pos"][1]
|
st_idx, ed_idx = self.svds[fea_name]["pos"][0], self.svds[fea_name]["pos"][1]
|
||||||
svd = self.svds[fea_name]["svd"]
|
svd = self.svds[fea_name]["svd"]
|
||||||
|
|
||||||
proj_fea = fea[:, st_idx: ed_idx]
|
proj_fea = fea[:, st_idx: ed_idx]
|
||||||
|
proj_fea = normalize(proj_fea, norm='l2')
|
||||||
proj_fea = svd.transform(proj_fea)
|
proj_fea = svd.transform(proj_fea)
|
||||||
if self._hyper_params["whiten"]:
|
if self._hyper_params["whiten"]:
|
||||||
proj_fea = proj_fea / (self.svds[fea_name]["std"] + 1e-6)
|
proj_fea = proj_fea / (self.svds[fea_name]["std"] + 1e-6)
|
||||||
|
proj_fea = normalize(proj_fea, norm='l2')
|
||||||
|
|
||||||
ret[:, st_idx: ed_idx] = proj_fea
|
ret.append(proj_fea)
|
||||||
|
|
||||||
|
ret = np.concatenate(ret, axis=1)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,24 +49,27 @@ def main():
|
||||||
|
|
||||||
# load search space
|
# load search space
|
||||||
datasets = load_datasets()
|
datasets = load_datasets()
|
||||||
data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
|
pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
|
||||||
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
||||||
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
||||||
|
|
||||||
# search in an exhaustive way
|
# search in an exhaustive way
|
||||||
for data_name, data_args in datasets.items():
|
for data_name, data_args in datasets.items():
|
||||||
for data_proc_name, data_proc_args in data_processes.items():
|
for pre_proc_name, pre_proc_args in pre_processes.items():
|
||||||
|
|
||||||
if 'market' in data_name:
|
if 'market' in data_name:
|
||||||
model_name = 'market_res50'
|
model_name = 'market_res50'
|
||||||
elif 'duke' in data_name:
|
elif 'duke' in data_name:
|
||||||
model_name = 'duke_res50'
|
model_name = 'duke_res50'
|
||||||
|
|
||||||
feature_full_name = data_name + "_" + data_proc_name + "_" + model_name
|
feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
|
||||||
print(feature_full_name)
|
print(feature_full_name)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(args.save_path, feature_full_name)):
|
||||||
|
print("[Search Extract]: config exists...")
|
||||||
|
continue
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
# load retrieval pipeline settings
|
||||||
cfg.datasets.merge_from_other_cfg(data_proc_args)
|
cfg.datasets.merge_from_other_cfg(pre_proc_args)
|
||||||
cfg.model.merge_from_other_cfg(models[model_name])
|
cfg.model.merge_from_other_cfg(models[model_name])
|
||||||
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import importlib
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from .utils.misc import check_exist, get_dir
|
from utils.misc import check_result_exist, get_dir, get_default_result_dict
|
||||||
|
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
from pyretri.index import build_index_helper, feature_loader
|
from pyretri.index import build_index_helper, feature_loader
|
||||||
|
@ -53,8 +53,8 @@ def main():
|
||||||
|
|
||||||
# load search space
|
# load search space
|
||||||
datasets = load_datasets()
|
datasets = load_datasets()
|
||||||
queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
|
indexes = importlib.import_module("{}.index_dict".format(args.search_modules)).indexes
|
||||||
evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
|
evaluates = importlib.import_module("{}.index_dict".format(args.search_modules)).evaluates
|
||||||
|
|
||||||
if os.path.exists(args.save_path):
|
if os.path.exists(args.save_path):
|
||||||
with open(args.save_path, "r") as f:
|
with open(args.save_path, "r") as f:
|
||||||
|
@ -64,8 +64,9 @@ def main():
|
||||||
|
|
||||||
for dir in os.listdir(args.fea_dir):
|
for dir in os.listdir(args.fea_dir):
|
||||||
for data_name, data_args in datasets.items():
|
for data_name, data_args in datasets.items():
|
||||||
for query_name, query_args in queries.items():
|
for index_name, index_args in indexes.items():
|
||||||
if data_name in dir:
|
if data_name in dir:
|
||||||
|
print(dir)
|
||||||
|
|
||||||
# get dirs
|
# get dirs
|
||||||
gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
|
gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
|
||||||
|
@ -73,21 +74,19 @@ def main():
|
||||||
# get evaluate setting
|
# get evaluate setting
|
||||||
evaluate_args = evaluates["reid_overall"]
|
evaluate_args = evaluates["reid_overall"]
|
||||||
|
|
||||||
for dim_proc in query_args.dim_processors.names:
|
for dim_proc in index_args.dim_processors.names:
|
||||||
if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
|
if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
|
||||||
query_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir
|
index_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir
|
||||||
|
|
||||||
for fea_name in fea_names:
|
for fea_name in fea_names:
|
||||||
|
result_dict = get_default_result_dict(dir, data_name, index_name, fea_name)
|
||||||
result_dict = get_default_result_dict(dir, data_name, query_name, fea_name)
|
if check_result_exist(result_dict, results):
|
||||||
if check_exist(result_dict, results):
|
|
||||||
print("[Search Query]: config exists...")
|
print("[Search Query]: config exists...")
|
||||||
continue
|
continue
|
||||||
print(data_name + '_' + fea_name + '_' + query_name)
|
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
# load retrieval pipeline settings
|
||||||
query_args.feature_names = [fea_name]
|
index_args.feature_names = [fea_name]
|
||||||
cfg.index.merge_from_other_cfg(query_args)
|
cfg.index.merge_from_other_cfg(index_args)
|
||||||
cfg.evaluate.merge_from_other_cfg(evaluate_args)
|
cfg.evaluate.merge_from_other_cfg(evaluate_args)
|
||||||
|
|
||||||
# load features
|
# load features
|
||||||
|
@ -95,12 +94,12 @@ def main():
|
||||||
gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])
|
gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])
|
||||||
|
|
||||||
# build helper and index features
|
# build helper and index features
|
||||||
query_helper = build_index_helper(cfg.index)
|
index_helper = build_index_helper(cfg.index)
|
||||||
query_result_info, _, _ = query_helper.do_index(query_fea, query_info, gallery_fea)
|
index_result_info, _, _ = index_helper.do_index(query_fea, query_info, gallery_fea)
|
||||||
|
|
||||||
# build helper and evaluate results
|
# build helper and evaluate results
|
||||||
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
||||||
mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
|
mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)
|
||||||
|
|
||||||
# record results
|
# record results
|
||||||
to_save_recall = dict()
|
to_save_recall = dict()
|
|
@ -19,6 +19,7 @@ models.add(
|
||||||
extracts.add(
|
extracts.add(
|
||||||
"market_res50",
|
"market_res50",
|
||||||
{
|
{
|
||||||
|
"assemble": 1,
|
||||||
"extractor": {
|
"extractor": {
|
||||||
"name": "ReIDSeries",
|
"name": "ReIDSeries",
|
||||||
"ReIDSeries": {
|
"ReIDSeries": {
|
||||||
|
@ -39,13 +40,14 @@ models.add(
|
||||||
{
|
{
|
||||||
"name": "ft_net",
|
"name": "ft_net",
|
||||||
"ft_net": {
|
"ft_net": {
|
||||||
"load_checkpoint": "/home/songrenjie/projects/reID_baseline/model/ft_ResNet50/net_59.pth"
|
"load_checkpoint": "/home/songrenjie/projects/reID_baseline/model/ft_ResNet50/res50_duke.pth"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
extracts.add(
|
extracts.add(
|
||||||
"duke_res50",
|
"duke_res50",
|
||||||
{
|
{
|
||||||
|
"assemble": 1,
|
||||||
"extractor": {
|
"extractor": {
|
||||||
"name": "ReIDSeries",
|
"name": "ReIDSeries",
|
||||||
"ReIDSeries": {
|
"ReIDSeries": {
|
||||||
|
|
|
@ -3,10 +3,10 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
queries = SearchModules()
|
indexes = SearchModules()
|
||||||
evaluates = SearchModules()
|
evaluates = SearchModules()
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"no_fea_process",
|
"no_fea_process",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -32,7 +32,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"l2_normalize",
|
"l2_normalize",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -58,7 +58,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_wo_whiten",
|
"pca_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -67,11 +67,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartPCA"],
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -89,7 +90,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_whiten",
|
"pca_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -98,11 +99,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartPCA"],
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -120,7 +122,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_wo_whiten",
|
"svd_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -129,11 +131,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartSVD"],
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -151,7 +154,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_whiten",
|
"svd_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -160,11 +163,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartSVD"],
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -211,5 +215,5 @@ evaluates.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
queries.check_valid(cfg["index"])
|
indexes.check_valid(cfg["index"])
|
||||||
evaluates.check_valid(cfg["evaluate"])
|
evaluates.check_valid(cfg["evaluate"])
|
|
@ -3,9 +3,9 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
data_processes = SearchModules()
|
pre_processes = SearchModules()
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Direct256128",
|
"Direct256128",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -31,4 +31,4 @@ data_processes.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
data_processes.check_valid(cfg["datasets"])
|
pre_processes.check_valid(cfg["datasets"])
|
|
@ -1,149 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import json
|
|
||||||
import importlib
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from pyretri.config import get_defaults_cfg
|
|
||||||
from pyretri.query import build_query_helper
|
|
||||||
from pyretri.evaluate import build_evaluate_helper
|
|
||||||
from pyretri.query.utils import feature_loader
|
|
||||||
|
|
||||||
fea_names = ["output"]
|
|
||||||
|
|
||||||
task_mapping = {
|
|
||||||
"market_gallery": {
|
|
||||||
"gallery": "market_gallery",
|
|
||||||
"query": "market_query",
|
|
||||||
"train_fea_dir": "market_gallery"
|
|
||||||
},
|
|
||||||
"duke_gallery": {
|
|
||||||
"gallery": "duke_gallery",
|
|
||||||
"query": "duke_query",
|
|
||||||
"train_fea_dir": "duke_gallery"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def check_exist(now_res, exist_results):
|
|
||||||
for e_r in exist_results:
|
|
||||||
totoal_equal = True
|
|
||||||
for key in now_res:
|
|
||||||
if now_res[key] != e_r[key]:
|
|
||||||
totoal_equal = False
|
|
||||||
break
|
|
||||||
if totoal_equal:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
|
||||||
parser.add_argument('--fea_dir', '-f', default=None, type=str, help="path of feature directory", required=True)
|
|
||||||
parser.add_argument(
|
|
||||||
"--search_modules",
|
|
||||||
"-m",
|
|
||||||
default="",
|
|
||||||
help="name of search module's directory",
|
|
||||||
type=str,
|
|
||||||
required=True
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_path", "-s", default=None, type=str, required=True)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_result_dict(dir, task_name, query_name, fea_name):
|
|
||||||
|
|
||||||
result_dict = {
|
|
||||||
"task_name": task_name.split("_")[0],
|
|
||||||
"dataprocess": dir.split("_")[0],
|
|
||||||
"model_name": "_".join(dir.split("_")[-2:]),
|
|
||||||
"feature_map_name": fea_name.split("_")[0],
|
|
||||||
"fea_process_name": query_name
|
|
||||||
}
|
|
||||||
if fea_name == "output":
|
|
||||||
result_dict["aggregator_name"] = "none"
|
|
||||||
else:
|
|
||||||
result_dict["aggregator_name"] = fea_name.split("_")[1]
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
|
||||||
|
|
||||||
queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
|
|
||||||
evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
|
|
||||||
|
|
||||||
if os.path.exists(args.save_path):
|
|
||||||
with open(args.save_path, "r") as f:
|
|
||||||
results = json.load(f)
|
|
||||||
else:
|
|
||||||
results = list()
|
|
||||||
|
|
||||||
q_cnt = 0
|
|
||||||
for dir in os.listdir(args.fea_dir):
|
|
||||||
q_cnt += 1
|
|
||||||
print("Processing {} / {} queries...".format(q_cnt, len(queries)))
|
|
||||||
for task_name in task_mapping:
|
|
||||||
for query_name, query_args in queries.items():
|
|
||||||
if task_name in dir:
|
|
||||||
for fea_name in fea_names:
|
|
||||||
|
|
||||||
gallery_fea_dir = os.path.join(args.fea_dir, dir)
|
|
||||||
query_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["query"])
|
|
||||||
train_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["train_fea_dir"])
|
|
||||||
|
|
||||||
assert os.path.exists(gallery_fea_dir), gallery_fea_dir
|
|
||||||
assert os.path.exists(query_fea_dir), query_fea_dir
|
|
||||||
assert os.path.exists(train_fea_dir), train_fea_dir
|
|
||||||
|
|
||||||
for post_proc in ["PartPCA", "PartSVD"]:
|
|
||||||
if post_proc in query_args.post_processors.names:
|
|
||||||
query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
|
|
||||||
query_args.gallery_fea_dir, query_args.query_fea_dir = gallery_fea_dir, query_fea_dir
|
|
||||||
query_args.feature_names = [fea_name]
|
|
||||||
eval_args = evaluates["reid_overall"]
|
|
||||||
|
|
||||||
result_dict = get_default_result_dict(dir, task_name, query_name, fea_name)
|
|
||||||
|
|
||||||
if check_exist(result_dict, results):
|
|
||||||
print("[Search Query]: config exists...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
cfg.query.merge_from_other_cfg(query_args)
|
|
||||||
cfg.evaluate.merge_from_other_cfg(eval_args)
|
|
||||||
|
|
||||||
query_helper = build_query_helper(cfg.query)
|
|
||||||
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
|
||||||
|
|
||||||
query_fea, query_info_dicts, _ = feature_loader.load(cfg.query.query_fea_dir,
|
|
||||||
cfg.query.feature_names)
|
|
||||||
gallery_fea, gallery_info_dicts, _ = feature_loader.load(cfg.query.gallery_fea_dir,
|
|
||||||
cfg.query.feature_names)
|
|
||||||
|
|
||||||
query_result_info_dicts, _, _ = query_helper.do_query(query_fea, query_info_dicts, gallery_fea)
|
|
||||||
mAP, recall_at_k = evaluate_helper.do_eval(query_result_info_dicts, gallery_info_dicts)
|
|
||||||
|
|
||||||
to_save_dict = dict()
|
|
||||||
for k in recall_at_k:
|
|
||||||
to_save_dict[str(k)] = recall_at_k[k]
|
|
||||||
|
|
||||||
result_dict["mAP"] = float(mAP)
|
|
||||||
result_dict["recall_at_k"] = to_save_dict
|
|
||||||
print(result_dict)
|
|
||||||
assert False
|
|
||||||
|
|
||||||
results.append(result_dict)
|
|
||||||
with open(args.save_path, "w") as f:
|
|
||||||
json.dump(results, f)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -30,17 +30,11 @@ def load_datasets():
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
|
|
||||||
def check_exist(save_path, full_name):
|
|
||||||
if os.path.exists(os.path.join(save_path, full_name)):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||||
parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for feature")
|
parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for feature")
|
||||||
parser.add_argument("--search_modules", "-sm", default="", type=str, help="name of search module's directory")
|
parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
@ -58,24 +52,24 @@ def main():
|
||||||
|
|
||||||
# load search space
|
# load search space
|
||||||
datasets = load_datasets()
|
datasets = load_datasets()
|
||||||
data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
|
pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
|
||||||
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
||||||
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
||||||
|
|
||||||
# search in an exhaustive way
|
# search in an exhaustive way
|
||||||
for data_name, data_args in datasets.items():
|
for data_name, data_args in datasets.items():
|
||||||
for data_proc_name, data_proc_args in data_processes.items():
|
for pre_proc_name, pre_proc_args in pre_processes.items():
|
||||||
for model_name, model_args in models.items():
|
for model_name, model_args in models.items():
|
||||||
|
|
||||||
feature_full_name = data_name + "_" + data_proc_name + "_" + model_name
|
feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
|
||||||
print(feature_full_name)
|
print(feature_full_name)
|
||||||
|
|
||||||
if check_exist(args.save_path, feature_full_name):
|
if os.path.exists(os.path.join(args.save_path, feature_full_name)):
|
||||||
print("[Search Extract]: config exists...")
|
print("[Search Extract]: config exists...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
# load retrieval pipeline settings
|
||||||
cfg.datasets.merge_from_other_cfg(data_proc_args)
|
cfg.datasets.merge_from_other_cfg(pre_proc_args)
|
||||||
cfg.model.merge_from_other_cfg(model_args)
|
cfg.model.merge_from_other_cfg(model_args)
|
||||||
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
||||||
|
|
||||||
|
|
|
@ -5,13 +5,26 @@ import importlib
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from .utils.misc import check_exist, get_dir
|
from utils.misc import check_result_exist, get_dir, get_default_result_dict
|
||||||
|
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
from pyretri.index import build_index_helper, feature_loader
|
from pyretri.index import build_index_helper, feature_loader
|
||||||
from pyretri.evaluate import build_evaluate_helper
|
from pyretri.evaluate import build_evaluate_helper
|
||||||
|
|
||||||
|
|
||||||
|
# # gap, gmp, gem, spoc, crow
|
||||||
|
# vgg_fea = ["pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
|
||||||
|
# "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow",
|
||||||
|
# "fc"]
|
||||||
|
# res_fea = ["pool3_GAP", "pool3_GMP", "pool3_GeM", "pool3_SPoC", "pool4_Crow",
|
||||||
|
# "pool4_GAP", "pool4_GMP", "pool4_GeM", "pool4_SPoC", "pool4_Crow",
|
||||||
|
# "pool5_GAP", "pool5_GMP", "pool5_GeM", "pool5_SPoC", "pool5_Crow"]
|
||||||
|
|
||||||
|
# # scda, rmca
|
||||||
|
# vgg_fea = ["pool5_SCDA", "pool5_RMAC"]
|
||||||
|
# res_fea = ["pool5_SCDA", "pool5_RMAC"]
|
||||||
|
|
||||||
|
# pwa
|
||||||
vgg_fea = ["pool5_PWA"]
|
vgg_fea = ["pool5_PWA"]
|
||||||
res_fea = ["pool5_PWA"]
|
res_fea = ["pool5_PWA"]
|
||||||
|
|
||||||
|
@ -80,8 +93,8 @@ def main():
|
||||||
|
|
||||||
# load search space
|
# load search space
|
||||||
datasets = load_datasets()
|
datasets = load_datasets()
|
||||||
queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
|
indexes = importlib.import_module("{}.index_dict".format(args.search_modules)).indexes
|
||||||
evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
|
evaluates = importlib.import_module("{}.index_dict".format(args.search_modules)).evaluates
|
||||||
|
|
||||||
if os.path.exists(args.save_path):
|
if os.path.exists(args.save_path):
|
||||||
with open(args.save_path, "r") as f:
|
with open(args.save_path, "r") as f:
|
||||||
|
@ -91,8 +104,10 @@ def main():
|
||||||
|
|
||||||
for dir in os.listdir(args.fea_dir):
|
for dir in os.listdir(args.fea_dir):
|
||||||
for data_name, data_args in datasets.items():
|
for data_name, data_args in datasets.items():
|
||||||
for query_name, query_args in queries.items():
|
for index_name, index_args in indexes.items():
|
||||||
if data_name in dir:
|
if data_name in dir:
|
||||||
|
print(dir)
|
||||||
|
|
||||||
# get dirs
|
# get dirs
|
||||||
gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
|
gallery_fea_dir, query_fea_dir, train_fea_dir = get_dir(args.fea_dir, dir, data_args)
|
||||||
|
|
||||||
|
@ -102,19 +117,20 @@ def main():
|
||||||
# get feature names
|
# get feature names
|
||||||
fea_names = get_fea_names(gallery_fea_dir)
|
fea_names = get_fea_names(gallery_fea_dir)
|
||||||
|
|
||||||
for post_proc in query_args.post_processors.names:
|
# set train feature path for dimension reduction processes
|
||||||
if post_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
|
for dim_proc in index_args.dim_processors.names:
|
||||||
query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
|
if dim_proc in ["PartPCA", "PartSVD", "PCA", "SVD"]:
|
||||||
|
index_args.dim_processors[dim_proc].train_fea_dir = train_fea_dir
|
||||||
|
|
||||||
for fea_name in fea_names:
|
for fea_name in fea_names:
|
||||||
result_dict = get_default_result_dict(dir, data_name, query_name, fea_name)
|
result_dict = get_default_result_dict(dir, data_name, index_name, fea_name)
|
||||||
if check_exist(result_dict, results):
|
if check_result_exist(result_dict, results):
|
||||||
print("[Search Query]: config exists...")
|
print("[Search Query]: config exists...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
# load retrieval pipeline settings
|
||||||
query_args.feature_names = [fea_name]
|
index_args.feature_names = [fea_name]
|
||||||
cfg.index.merge_from_other_cfg(query_args)
|
cfg.index.merge_from_other_cfg(index_args)
|
||||||
cfg.evaluate.merge_from_other_cfg(evaluate_args)
|
cfg.evaluate.merge_from_other_cfg(evaluate_args)
|
||||||
|
|
||||||
# load features
|
# load features
|
||||||
|
@ -122,12 +138,12 @@ def main():
|
||||||
gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])
|
gallery_fea, gallery_info, _ = feature_loader.load(gallery_fea_dir, [fea_name])
|
||||||
|
|
||||||
# build helper and index features
|
# build helper and index features
|
||||||
query_helper = build_query_helper(cfg.query)
|
index_helper = build_index_helper(cfg.index)
|
||||||
query_result_info, _, _ = query_helper.do_query(query_fea, query_info, gallery_fea)
|
index_result_info, _, _ = index_helper.do_index(query_fea, query_info, gallery_fea)
|
||||||
|
|
||||||
# build helper and evaluate results
|
# build helper and evaluate results
|
||||||
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
||||||
mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
|
mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)
|
||||||
|
|
||||||
# record results
|
# record results
|
||||||
to_save_recall = dict()
|
to_save_recall = dict()
|
|
@ -3,10 +3,10 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
queries = SearchModules()
|
indexes = SearchModules()
|
||||||
evaluates = SearchModules()
|
evaluates = SearchModules()
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_wo_whiten",
|
"pca_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -15,11 +15,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartPCA"],
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_whiten",
|
"pca_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -46,11 +47,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartPCA"],
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -68,7 +70,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_wo_whiten",
|
"svd_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -77,11 +79,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartSVD"],
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -99,7 +102,7 @@ queries.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_whiten",
|
"svd_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -108,11 +111,12 @@ queries.add(
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"dim_processors": {
|
"dim_processors": {
|
||||||
"names": ["PartSVD"],
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -150,5 +154,5 @@ evaluates.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
queries.check_valid(cfg["index"])
|
indexes.check_valid(cfg["index"])
|
||||||
evaluates.check_valid(cfg["evaluate"])
|
evaluates.check_valid(cfg["evaluate"])
|
|
@ -3,9 +3,9 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
data_processes = SearchModules()
|
pre_processes = SearchModules()
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Shorter256Center224",
|
"Shorter256Center224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -16,8 +16,8 @@ data_processes.add(
|
||||||
"name": "CollateFn"
|
"name": "CollateFn"
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
"names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
|
"names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
|
||||||
"ResizeShorter": {
|
"ShorterResize": {
|
||||||
"size": 256
|
"size": 256
|
||||||
},
|
},
|
||||||
"CenterCrop": {
|
"CenterCrop": {
|
||||||
|
@ -31,7 +31,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Direct224",
|
"Direct224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -54,7 +54,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"PadResize224",
|
"PadResize224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -80,4 +80,4 @@ data_processes.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
data_processes.check_valid(cfg["datasets"])
|
pre_processes.check_valid(cfg["datasets"])
|
|
@ -21,9 +21,9 @@ def load_datasets():
|
||||||
"cub_query": os.path.join(data_json_dir, "cub_query.json"),
|
"cub_query": os.path.join(data_json_dir, "cub_query.json"),
|
||||||
"indoor_gallery": os.path.join(data_json_dir, "indoor_gallery.json"),
|
"indoor_gallery": os.path.join(data_json_dir, "indoor_gallery.json"),
|
||||||
"indoor_query": os.path.join(data_json_dir, "indoor_query.json"),
|
"indoor_query": os.path.join(data_json_dir, "indoor_query.json"),
|
||||||
"caltech101_gallery": os.path.join(data_json_dir, "caltech101_gallery.json"),
|
"caltech_gallery": os.path.join(data_json_dir, "caltech_gallery.json"),
|
||||||
"caltech101_query": os.path.join(data_json_dir, "caltech101_query.json"),
|
"caltech_query": os.path.join(data_json_dir, "caltech_query.json"),
|
||||||
"paris": os.path.join(data_json_dir, "paris.json"),
|
"paris_all": os.path.join(data_json_dir, "paris.json"),
|
||||||
}
|
}
|
||||||
for data_path in datasets.values():
|
for data_path in datasets.values():
|
||||||
assert os.path.exists(data_path), "non-exist dataset path {}".format(data_path)
|
assert os.path.exists(data_path), "non-exist dataset path {}".format(data_path)
|
||||||
|
@ -48,25 +48,30 @@ def main():
|
||||||
# init retrieval pipeline settings
|
# init retrieval pipeline settings
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
data_processes = importlib.import_module("{}.data_process_dict".format(args.search_modules)).data_processes
|
# load search space
|
||||||
|
datasets = load_datasets()
|
||||||
|
pre_processes = importlib.import_module("{}.pre_process_dict".format(args.search_modules)).pre_processes
|
||||||
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
models = importlib.import_module("{}.extract_dict".format(args.search_modules)).models
|
||||||
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
extracts = importlib.import_module("{}.extract_dict".format(args.search_modules)).extracts
|
||||||
|
|
||||||
datasets = load_datasets()
|
|
||||||
|
|
||||||
for data_name, data_args in datasets.items():
|
for data_name, data_args in datasets.items():
|
||||||
for data_proc_name, data_proc_args in data_processes.items():
|
for pre_proc_name, pre_proc_args in pre_processes.items():
|
||||||
for model_name, model_args in models.items():
|
for model_name, model_args in models.items():
|
||||||
|
|
||||||
feature_full_name = data_process_name + "_" + dataset_name + "_" + model_name
|
feature_full_name = data_name + "_" + pre_proc_name + "_" + model_name
|
||||||
print(feature_full_name)
|
print(feature_full_name)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(args.save_path, feature_full_name)):
|
||||||
|
print("[Search Extract]: config exists...")
|
||||||
|
continue
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
# load retrieval pipeline settings
|
||||||
cfg.datasets.merge_from_other_cfg(data_proc_args)
|
cfg.datasets.merge_from_other_cfg(pre_proc_args)
|
||||||
cfg.model.merge_from_other_cfg(model_args)
|
cfg.model.merge_from_other_cfg(model_args)
|
||||||
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
cfg.extract.merge_from_other_cfg(extracts[model_name])
|
||||||
|
|
||||||
pwa_train_fea_dir = os.path.join("/data/my_features/gap_gmp_gem_crow_spoc", feature_full_name)
|
# set train feature path for pwa
|
||||||
|
pwa_train_fea_dir = os.path.join("/data/features/test_gap_gmp_gem_crow_spoc", feature_full_name)
|
||||||
if "query" in pwa_train_fea_dir:
|
if "query" in pwa_train_fea_dir:
|
||||||
pwa_train_fea_dir.replace("query", "gallery")
|
pwa_train_fea_dir.replace("query", "gallery")
|
||||||
elif "paris" in pwa_train_fea_dir:
|
elif "paris" in pwa_train_fea_dir:
|
||||||
|
|
|
@ -3,10 +3,10 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
queries = SearchModules()
|
indexes = SearchModules()
|
||||||
evaluates = SearchModules()
|
evaluates = SearchModules()
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_wo_whiten",
|
"pca_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -14,16 +14,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartPCA",
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -31,13 +32,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_whiten",
|
"pca_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -45,16 +46,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartPCA",
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -62,13 +64,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_wo_whiten",
|
"svd_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -76,16 +78,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartSVD",
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -93,13 +96,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_whiten",
|
"svd_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -107,16 +110,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartSVD",
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -124,7 +128,7 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,5 +154,5 @@ evaluates.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
queries.check_valid(cfg["query"])
|
indexes.check_valid(cfg["index"])
|
||||||
evaluates.check_valid(cfg["evaluate"])
|
evaluates.check_valid(cfg["evaluate"])
|
|
@ -3,9 +3,9 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
data_processes = SearchModules()
|
pre_processes = SearchModules()
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Shorter256Center224",
|
"Shorter256Center224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -16,8 +16,8 @@ data_processes.add(
|
||||||
"name": "CollateFn"
|
"name": "CollateFn"
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
"names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
|
"names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
|
||||||
"ResizeShorter": {
|
"ShorterResize": {
|
||||||
"size": 256
|
"size": 256
|
||||||
},
|
},
|
||||||
"CenterCrop": {
|
"CenterCrop": {
|
||||||
|
@ -31,7 +31,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Direct224",
|
"Direct224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -54,7 +54,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"PadResize224",
|
"PadResize224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -80,4 +80,4 @@ data_processes.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
data_processes.check_valid(cfg["datasets"])
|
pre_processes.check_valid(cfg["datasets"])
|
|
@ -1,167 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import json
|
|
||||||
import importlib
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from pyretri.config import get_defaults_cfg
|
|
||||||
from pyretri.query import build_query_helper
|
|
||||||
from pyretri.evaluate import build_evaluate_helper
|
|
||||||
from pyretri.index import feature_loader
|
|
||||||
|
|
||||||
|
|
||||||
vgg_fea = ["pool5_PWA"]
|
|
||||||
res_fea = ["pool5_PWA"]
|
|
||||||
|
|
||||||
task_mapping = {
|
|
||||||
"oxford_gallery": {
|
|
||||||
"gallery": "oxford_gallery",
|
|
||||||
"query": "oxford_query",
|
|
||||||
"train_fea_dir": "paris"
|
|
||||||
},
|
|
||||||
"cub_gallery": {
|
|
||||||
"gallery": "cub_gallery",
|
|
||||||
"query": "cub_query",
|
|
||||||
"train_fea_dir": "cub_gallery"
|
|
||||||
},
|
|
||||||
"indoor_gallery": {
|
|
||||||
"gallery": "indoor_gallery",
|
|
||||||
"query": "indoor_query",
|
|
||||||
"train_fea_dir": "indoor_gallery"
|
|
||||||
},
|
|
||||||
"caltech101_gallery": {
|
|
||||||
"gallery": "caltech101_gallery",
|
|
||||||
"query": "caltech101_query",
|
|
||||||
"train_fea_dir": "caltech101_gallery"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def check_exist(now_res, exist_results):
|
|
||||||
for e_r in exist_results:
|
|
||||||
totoal_equal = True
|
|
||||||
for key in now_res:
|
|
||||||
if now_res[key] != e_r[key]:
|
|
||||||
totoal_equal = False
|
|
||||||
break
|
|
||||||
if totoal_equal:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_result_dict(dir, task_name, query_name, fea_name):
|
|
||||||
result_dict = {
|
|
||||||
"task_name": task_name.split("_")[0],
|
|
||||||
"dataprocess": dir.split("_")[0],
|
|
||||||
"model_name": "_".join(dir.split("_")[-2:]),
|
|
||||||
"feature_map_name": fea_name.split("_")[0],
|
|
||||||
"fea_process_name": query_name
|
|
||||||
}
|
|
||||||
|
|
||||||
if fea_name == "fc":
|
|
||||||
result_dict["aggregator_name"] = "none"
|
|
||||||
else:
|
|
||||||
result_dict["aggregator_name"] = fea_name.split("_")[1]
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
|
||||||
parser.add_argument('--fea_dir', '-fd', default=None, type=str, help="path of feature dirs", required=True)
|
|
||||||
parser.add_argument("--search_modules", "-sm", default=None, type=str, help="name of search module's directory")
|
|
||||||
parser.add_argument("--save_path", "-sp", default=None, type=str, help="path for saving results")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
# init args
|
|
||||||
args = parse_args()
|
|
||||||
assert args.fea_dir is not None, 'the feature directory must be provided!'
|
|
||||||
assert args.search_modules is not None, 'the search modules must be provided!'
|
|
||||||
assert args.save_path is not None, 'the save path must be provided!'
|
|
||||||
|
|
||||||
# init retrieval pipeline settings
|
|
||||||
cfg = get_defaults_cfg()
|
|
||||||
|
|
||||||
# load search space
|
|
||||||
queries = importlib.import_module("{}.query_dict".format(args.search_modules)).queries
|
|
||||||
evaluates = importlib.import_module("{}.query_dict".format(args.search_modules)).evaluates
|
|
||||||
|
|
||||||
if os.path.exists(args.save_path):
|
|
||||||
with open(args.save_path, "r") as f:
|
|
||||||
results = json.load(f)
|
|
||||||
else:
|
|
||||||
results = list()
|
|
||||||
|
|
||||||
q_cnt = 0
|
|
||||||
for dir in os.listdir(args.fea_dir):
|
|
||||||
q_cnt += 1
|
|
||||||
print("Processing {} / {} queries...".format(q_cnt, len(queries)))
|
|
||||||
for query_name, query_args in queries.items():
|
|
||||||
for task_name in task_mapping:
|
|
||||||
if task_name in dir:
|
|
||||||
|
|
||||||
if "vgg" in gallery_fea_dir:
|
|
||||||
fea_names = vgg_fea
|
|
||||||
else:
|
|
||||||
fea_names = res_fea
|
|
||||||
|
|
||||||
for fea_name in fea_names:
|
|
||||||
gallery_fea_dir = os.path.join(args.fea_dir, dir)
|
|
||||||
query_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["query"])
|
|
||||||
train_fea_dir = gallery_fea_dir.replace(task_name, task_mapping[task_name]["train_fea_dir"])
|
|
||||||
|
|
||||||
for post_proc in ["PartPCA", "PartSVD"]:
|
|
||||||
if post_proc in query_args.post_processors.names:
|
|
||||||
query_args.post_processors[post_proc].train_fea_dir = train_fea_dir
|
|
||||||
|
|
||||||
query.gallery_fea_dir, query.query_fea_dir = gallery_fea_dir, query_fea_dir
|
|
||||||
|
|
||||||
query.feature_names = [fea_name]
|
|
||||||
if task_name == "oxford_base":
|
|
||||||
evaluate = evaluates["oxford_overall"]
|
|
||||||
else:
|
|
||||||
evaluate = evaluates["overall"]
|
|
||||||
|
|
||||||
result_dict = get_default_result_dict(dir, task_name, query_name, fea_name)
|
|
||||||
|
|
||||||
if check_exist(result_dict, results):
|
|
||||||
print("[Search Query]: config exists...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# load retrieval pipeline settings
|
|
||||||
cfg.query.merge_from_other_cfg(query)
|
|
||||||
cfg.evaluate.merge_from_other_cfg(evaluate)
|
|
||||||
|
|
||||||
# load features
|
|
||||||
query_fea, query_info, _ = feature_loader.load(cfg.query.query_fea_dir, cfg.query.feature_names)
|
|
||||||
gallery_fea, gallery_info, _ = feature_loader.load(cfg.query.gallery_fea_dir,
|
|
||||||
cfg.query.feature_names)
|
|
||||||
|
|
||||||
# build helper and index features
|
|
||||||
query_helper = build_query_helper(cfg.query)
|
|
||||||
query_result_info, _, _ = query_helper.do_query(query_fea, query_info, gallery_fea)
|
|
||||||
|
|
||||||
# build helper and evaluate results
|
|
||||||
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
|
||||||
mAP, recall_at_k = evaluate_helper.do_eval(query_result_info, gallery_info)
|
|
||||||
|
|
||||||
# save results
|
|
||||||
to_save_dict = dict()
|
|
||||||
for k in recall_at_k:
|
|
||||||
to_save_dict[str(k)] = recall_at_k[k]
|
|
||||||
result_dict["mAP"] = float(mAP)
|
|
||||||
result_dict["recall_at_k"] = to_save_dict
|
|
||||||
|
|
||||||
results.append(result_dict)
|
|
||||||
with open(args.save_path, "w") as f:
|
|
||||||
json.dump(results, f)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -3,10 +3,10 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
queries = SearchModules()
|
indexes = SearchModules()
|
||||||
evaluates = SearchModules()
|
evaluates = SearchModules()
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_wo_whiten",
|
"pca_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -14,16 +14,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartPCA",
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -31,13 +32,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"pca_whiten",
|
"pca_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -45,16 +46,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartPCA",
|
"names": ["L2Normalize", "PCA", "L2Normalize"],
|
||||||
"PartPCA": {
|
"PCA": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 512
|
"proj_dim": 512,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -62,13 +64,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_wo_whiten",
|
"svd_wo_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -76,16 +78,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartSVD",
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": False,
|
"whiten": False,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -93,13 +96,13 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
queries.add(
|
indexes.add(
|
||||||
"svd_whiten",
|
"svd_whiten",
|
||||||
{
|
{
|
||||||
"gallery_fea_dir": "",
|
"gallery_fea_dir": "",
|
||||||
|
@ -107,16 +110,17 @@ queries.add(
|
||||||
|
|
||||||
"feature_names": [],
|
"feature_names": [],
|
||||||
|
|
||||||
"post_processor": {
|
"dim_processors": {
|
||||||
"name": "PartSVD",
|
"names": ["L2Normalize", "SVD", "L2Normalize"],
|
||||||
"PartSVD": {
|
"SVD": {
|
||||||
"whiten": True,
|
"whiten": True,
|
||||||
"train_fea_dir": "",
|
"train_fea_dir": "",
|
||||||
"proj_dim": 511
|
"proj_dim": 511,
|
||||||
|
"l2": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"database_enhance": {
|
"feature_enhancer": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -124,7 +128,7 @@ queries.add(
|
||||||
"name": "KNN"
|
"name": "KNN"
|
||||||
},
|
},
|
||||||
|
|
||||||
"re_rank": {
|
"re_ranker": {
|
||||||
"name": "Identity"
|
"name": "Identity"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,5 +154,5 @@ evaluates.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
queries.check_valid(cfg["query"])
|
indexes.check_valid(cfg["index"])
|
||||||
evaluates.check_valid(cfg["evaluate"])
|
evaluates.check_valid(cfg["evaluate"])
|
|
@ -3,9 +3,9 @@
|
||||||
from utils.search_modules import SearchModules
|
from utils.search_modules import SearchModules
|
||||||
from pyretri.config import get_defaults_cfg
|
from pyretri.config import get_defaults_cfg
|
||||||
|
|
||||||
data_processes = SearchModules()
|
pre_processes = SearchModules()
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Shorter256Center224",
|
"Shorter256Center224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -16,8 +16,8 @@ data_processes.add(
|
||||||
"name": "CollateFn"
|
"name": "CollateFn"
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
"names": ["ResizeShorter", "CenterCrop", "ToTensor", "Normalize"],
|
"names": ["ShorterResize", "CenterCrop", "ToTensor", "Normalize"],
|
||||||
"ResizeShorter": {
|
"ShorterResize": {
|
||||||
"size": 256
|
"size": 256
|
||||||
},
|
},
|
||||||
"CenterCrop": {
|
"CenterCrop": {
|
||||||
|
@ -31,7 +31,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"Direct224",
|
"Direct224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -54,7 +54,7 @@ data_processes.add(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_processes.add(
|
pre_processes.add(
|
||||||
"PadResize224",
|
"PadResize224",
|
||||||
{
|
{
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
@ -80,4 +80,4 @@ data_processes.add(
|
||||||
|
|
||||||
cfg = get_defaults_cfg()
|
cfg = get_defaults_cfg()
|
||||||
|
|
||||||
data_processes.check_valid(cfg["datasets"])
|
pre_processes.check_valid(cfg["datasets"])
|
|
@ -4,9 +4,10 @@ import os
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import csv
|
|
||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
|
from utils.misc import save_to_csv, filter_by_keywords
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||||
|
@ -22,62 +23,30 @@ def show_results(results):
|
||||||
print(results[i])
|
print(results[i])
|
||||||
|
|
||||||
|
|
||||||
def save_to_csv(results, csv_path):
|
|
||||||
start = []
|
|
||||||
col_num = 12
|
|
||||||
if not os.path.exists(csv_path):
|
|
||||||
start = ["data_process", "model", "feature", "fea_process", "market_mAP", "market_mAP_re", "market_R1",
|
|
||||||
"market_R1_re", "duke_mAP", "duke_mAP_re", "duke_R1", "duke_R1_re"]
|
|
||||||
with open(csv_path, 'a+') as f:
|
|
||||||
csv_write = csv.writer(f)
|
|
||||||
if len(start) > 0:
|
|
||||||
csv_write.writerow(start)
|
|
||||||
for i in range(len(results)):
|
|
||||||
data_row = [0 for x in range(col_num)]
|
|
||||||
data_row[0] = results[i]["dataprocess"]
|
|
||||||
data_row[1] = results[i]["model_name"]
|
|
||||||
data_row[2] = results[i]["feature_map_name"]
|
|
||||||
data_row[3] = results[i]["fea_process_name"]
|
|
||||||
if results[i]["task_name"] == 'market':
|
|
||||||
data_row[4] = results[i]["mAP"]
|
|
||||||
data_row[6] = results[i]["recall_at_k"]['1']
|
|
||||||
elif results[i]["task_name"] == 'duke':
|
|
||||||
data_row[8] = results[i]["mAP"]
|
|
||||||
data_row[10] = results[i]["recall_at_k"]['1']
|
|
||||||
csv_write.writerow(data_row)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# init args
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
assert os.path.exists(args.results_json_path), 'the config file must be existed!'
|
assert os.path.exists(args.results_json_path), 'the config file must be existed!'
|
||||||
|
|
||||||
with open(args.results_json_path, "r") as f:
|
with open(args.results_json_path, "r") as f:
|
||||||
results = json.load(f)
|
results = json.load(f)
|
||||||
|
|
||||||
key_words = {
|
# save the search results in a csv format file.
|
||||||
'task_name': ['market'],
|
|
||||||
'dataprocess': list(),
|
|
||||||
'model_name': list(),
|
|
||||||
'feature_map_name': list(),
|
|
||||||
'aggregator_name': list(),
|
|
||||||
'fea_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
|
|
||||||
}
|
|
||||||
|
|
||||||
csv_path = '/home/songrenjie/projects/RetrievalToolBox/test.csv'
|
csv_path = '/home/songrenjie/projects/RetrievalToolBox/test.csv'
|
||||||
save_to_csv(results, csv_path)
|
save_to_csv(results, csv_path)
|
||||||
|
|
||||||
for key in key_words:
|
# define the keywords to be selected
|
||||||
no_match = []
|
keywords = {
|
||||||
if len(key_words[key]) == 0:
|
'data_name': ['market'],
|
||||||
continue
|
'pre_process_name': list(),
|
||||||
else:
|
'model_name': list(),
|
||||||
for i in range(len(results)):
|
'feature_map_name': list(),
|
||||||
if not results[i][key] in key_words[key]:
|
'aggregator_name': list(),
|
||||||
no_match.append(i)
|
'post_process_name': ['no_fea_process', 'l2_normalize', 'pca_whiten', 'pca_wo_whiten'],
|
||||||
for num in no_match[::-1]:
|
}
|
||||||
results.pop(num)
|
|
||||||
|
|
||||||
|
# show search results according to the given keywords
|
||||||
|
results = filter_by_keywords(results, keywords)
|
||||||
show_results(results)
|
show_results(results)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
import csv
|
||||||
|
|
||||||
|
def check_result_exist(now_res: Dict, exist_results: List) -> bool:
|
||||||
def check_exist(now_res: Dict, exist_results: List) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Check if the config exists.
|
Check if the config exists.
|
||||||
|
|
||||||
|
@ -45,14 +46,14 @@ def get_dir(root_path: str, dir: str, dataset: Dict) -> (str, str, str):
|
||||||
return gallery_fea_dir, query_fea_dir, train_fea_dir
|
return gallery_fea_dir, query_fea_dir, train_fea_dir
|
||||||
|
|
||||||
|
|
||||||
def get_default_result_dict(dir, data_name, query_name, fea_name) -> Dict:
|
def get_default_result_dict(dir: str, data_name: str, index_name: str, fea_name: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Get the default result dict based on the experimental factors.
|
Get the default result dict based on the experimental factors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dir (str): the path of one single extracted feature directory.
|
dir (str): the path of one single extracted feature directory.
|
||||||
data_name (str): the name of the dataset.
|
data_name (str): the name of the dataset.
|
||||||
query_name (str): the name of query process.
|
index_name (str): the name of query process.
|
||||||
fea_name (str): the name of the features to be loaded.
|
fea_name (str): the name of the features to be loaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -60,16 +61,70 @@ def get_default_result_dict(dir, data_name, query_name, fea_name) -> Dict:
|
||||||
"""
|
"""
|
||||||
result_dict = {
|
result_dict = {
|
||||||
"data_name": data_name.split("_")[0],
|
"data_name": data_name.split("_")[0],
|
||||||
"dataprocess": dir.split("_")[0],
|
"pre_process_name": dir.split("_")[2],
|
||||||
"model_name": "_".join(dir.split("_")[-2:]),
|
"model_name": "_".join(dir.split("_")[-2:]),
|
||||||
"feature_map_name": fea_name.split("_")[0],
|
"feature_map_name": fea_name.split("_")[0],
|
||||||
"fea_process_name": query_name
|
"post_process_name": index_name
|
||||||
}
|
}
|
||||||
|
|
||||||
if fea_name == "fc":
|
if len(fea_name.split("_")) == 1:
|
||||||
result_dict["aggregator_name"] = "none"
|
result_dict["aggregator_name"] = "none"
|
||||||
else:
|
else:
|
||||||
result_dict["aggregator_name"] = fea_name.split("_")[1]
|
result_dict["aggregator_name"] = fea_name.split("_")[1]
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_csv(results: List[Dict], csv_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Save the search results in a csv format file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (List): a list of retrieval results.
|
||||||
|
csv_path (str): the path for saving the csv file.
|
||||||
|
"""
|
||||||
|
start = ["data", "pre_process", "model", "feature_map", "aggregator", "post_process"]
|
||||||
|
for i in range(len(start)):
|
||||||
|
results = sorted(results, key=lambda result: result[start[len(start) - i - 1] + "_name"])
|
||||||
|
start.append('mAP')
|
||||||
|
start.append('Recall@1')
|
||||||
|
|
||||||
|
with open(csv_path, 'w') as f:
|
||||||
|
csv_write = csv.writer(f)
|
||||||
|
if len(start) > 0:
|
||||||
|
csv_write.writerow(start)
|
||||||
|
for i in range(len(results)):
|
||||||
|
data_row = [0 for x in range(len(start))]
|
||||||
|
data_row[0] = results[i]["data_name"]
|
||||||
|
data_row[1] = results[i]["pre_process_name"]
|
||||||
|
data_row[2] = results[i]["model_name"]
|
||||||
|
data_row[3] = results[i]["feature_map_name"]
|
||||||
|
data_row[4] = results[i]["aggregator_name"]
|
||||||
|
data_row[5] = results[i]["post_process_name"]
|
||||||
|
data_row[6] = results[i]["mAP"]
|
||||||
|
data_row[7] = results[i]["recall_at_k"]['1']
|
||||||
|
csv_write.writerow(data_row)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_by_keywords(results: List[Dict], keywords: Dict) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Filter the search results according to the given keywords
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (List): a list of retrieval results.
|
||||||
|
keywords (Dict): a dict containing keywords to be selected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
for key in keywords:
|
||||||
|
no_match = []
|
||||||
|
if len(keywords[key]) == 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for i in range(len(results)):
|
||||||
|
if not results[i][key] in keywords[key]:
|
||||||
|
no_match.append(i)
|
||||||
|
for num in no_match[::-1]:
|
||||||
|
results.pop(num)
|
||||||
|
return results
|
||||||
|
|
Loading…
Reference in New Issue