pull/1/head
hubenyi 2020-04-02 14:00:49 +08:00
parent d4145ece41
commit 0b23903d0b
171 changed files with 49676 additions and 0 deletions

73
README.md 100644
View File

@ -0,0 +1,73 @@
# PyRetri
## Introduction
PyRetri is an open-source deep learning-based image retrieval toolbox built on PyTorch.
### Major features
- **Modular Design**
We decompose deep learning-based image retrieval into several stages, so that one can easily construct an image retrieval pipeline by combining different modules.
- **Flexible Loading**
The toolbox can load several types of model parameters, including parameters with the same keys and shapes, parameters with different keys, and parameters with the same keys but different shapes.
- **Support of Multiple Methods**
The toolbox directly supports popular methods designed for deep learning-based image retrieval.
- **Combinations Search Tool**
We provide pipeline combination search scripts to help users find possible combinations of approaches with various hyper-parameters.
### Supported Methods
- **Pre-processing**
- DirectResize, PadResize, ShorterResize
- CenterCrop, TenCrop
- TwoFlip
- ToTensor, ToCaffeTensor
- Normalize
- **Feature Representation**
- GAP
- GMP
- [R-MAC](https://arxiv.org/pdf/1511.05879.pdf)
- [SPoC](https://arxiv.org/pdf/1510.07493.pdf)
- [CroW](https://arxiv.org/pdf/1512.04065.pdf)
- [SCDA](http://www.weixiushen.com/publication/tip17SCDA.pdf)
- [GeM](https://pdfs.semanticscholar.org/a2ca/e0ed91d8a3298b3209fc7ea0a4248b914386.pdf)
- [PWA](https://arxiv.org/abs/1705.01247)
- [PCB](http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf)
- **Post-processing**
- [SVD](https://link.springer.com/chapter/10.1007%2F978-3-662-39778-7_10)
- [PCA](http://pzs.dstu.dp.ua/DataMining/pca/bibl/Principal%20components%20analysis.pdf)
- [DBA](https://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf)
- [QE](https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf)
- [K-reciprocal](https://arxiv.org/pdf/1701.08398.pdf)
## License
This project is released under the [Apache 2.0 license](https://github.com/hby96/pyretri/blob/master/LICENSE).
## Installation
Please refer to INSTALL.md for installation and dataset preparation.
## Get Started
Please see GETTING_STARTED.md for the basic usage of PyRetri.
## Model Zoo
Results and models are available in MODEL_ZOO.md.
## Citation
If you use this toolbox in your research, please cite this project.
```shell
```

View File

@ -0,0 +1,306 @@
# Getting started
This page provides basic tutorials about the usage of PyRetri. For installation instructions and dataset preparation, please see [INSTALL.md](INSTALL.md).
## Make Data Json
After the gallery set and query set are separated, we package the information of each subset in pickle format for further processing (a sketch for inspecting the generated file is given at the end of this section). We use different types to handle differently structured folders: `general`, `oxford` and `reid`.
A general object recognition dataset collects images with the same label in one directory, and its folder structure should look like this:
```shell
# type: general
general_recognition
├── class A
│ ├── XXX.jpg
│ └── ···
├── class B
│ ├── XXX.jpg
│ └── ···
└── ···
```
Oxford5k is a typical dataset in the image retrieval field, and its folder structure is as follows:
```shell
# type: oxford
oxford
├── gt
│ ├── XXX.txt
│ └── ···
└── images
├── XXX.jpg
└── ···
```
Person re-identification datasets have already separated the query set and gallery set, and their folder structure should look like this:
```shell
# type: reid
person_re_identification
├── bounding_box_test
│ ├── XXX.jpg
│ └── ···
├── query
│ ├── XXX.jpg
│ └── ···
└── ···
```
After choosing the appropriate type, you can generate the data jsons by:
```shell
python3 main/make_data_json.py [-d ${dataset}] [-sp ${save_path}] [-t ${type}] [-gt ${ground_truth}]
```
Arguments:
- `dataset`: Path of the dataset for generating the data json file.
- `save_path`: Path for saving the output file.
- `type`: Type of the dataset. For datasets that collect images with the same label in one directory, use `general`. For the Oxford/Paris datasets, use `oxford`. For re-id datasets, use `reid`.
- `ground_truth`: Path of the ground truth information, which is required when generating the data json file for the Oxford/Paris datasets.
Examples:
```shell
# for dataset collecting images with the same label in one directory
python3 main/make_data_json.py -d /data/caltech101/gallery/ -sp data_jsons/caltech_gallery.json -t general
python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t general
# for oxford/paris dataset
python3 main/make_data_json.py -d /data/cbir/oxford/gallery/ -sp data_jsons/oxford_gallery.json -t oxford -gt /data/cbir/oxford/gt/
python3 main/make_data_json.py -d /data/cbir/oxford/query/ -sp data_jsons/oxford_query.json -t oxford -gt /data/cbir/oxford/gt/
# for re-id dataset
python3 main/make_data_json.py -d /data/market1501/bounding_box_test/ -sp data_jsons/market_gallery.json -t reid
python3 main/make_data_json.py -d /data/market1501/query/ -sp data_jsons/market_query.json -t reid
```
Note: The Oxford/Paris datasets store the ground truth of each query image in txt files, so remember to provide the path of the gt folder when generating their data json files.
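The generated data json is simply a pickle file. To double-check what has been packaged, you can load it with a minimal inspection sketch like the one below (the `info_dicts`, `path` and `label` keys follow the folder classes of this toolbox; other keys may be present depending on the dataset type):
```python
import pickle

# a minimal sketch: inspect a generated data json (pickle) file
with open("data_jsons/caltech_gallery.json", "rb") as f:
    data_info = pickle.load(f)

info_dicts = data_info["info_dicts"]   # one entry per image
print("number of images:", len(info_dicts))
print("first entry:", info_dicts[0])   # expected keys include 'path' and 'label'
```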
## Extract
All outputs (features and labels) will be saved to the save directory in pickle format; a sketch for loading them back is given after the examples below.
Extract feature for each data json file by:
```shell
python3 main/extract_feature.py [-dj ${data_json}] [-sp ${save_path}] [-cfg ${config_file}] [-si ${save_interval}]
```
Arguments:
- `data_json`: Path of the data json file to be extracted.
- `save_path`: Path for saving the output features in pickle format.
- `config_file`: Path of the configuration file in yaml format.
- `save_interval`: Optional. The number of features saved in each part file, 5000 by default.
```shell
# extract features of the gallery set and query set
python3 main/extract_feature.py -dj data_jsons/caltech_gallery.json -sp /data/features/caltech/gallery/ -cfg configs/caltech.yaml
python3 main/extract_feature.py -dj data_jsons/caltech_query.json -sp /data/features/caltech/query/ -cfg configs/caltech.yaml
```
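The saved part files can be loaded back through the toolbox's feature loader, which is the same interface used by `main/index.py`. A minimal sketch (the feature name below is a placeholder and must match the `feature_names` entry of your config):
```python
from retrieval_tool_box.index import feature_loader

# a minimal sketch: load features saved by main/extract_feature.py
# "pool5_GAP" is a placeholder; use the feature names configured in your yaml file
feature_names = ["pool5_GAP"]
gallery_fea, gallery_info, _ = feature_loader.load("/data/features/caltech/gallery/", feature_names)

print(gallery_fea.shape)   # (num_images, feature_dim)
print(gallery_info[0])     # per-image information saved alongside the features
```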
## Index
The paths of the query set features and gallery set features are specified in the config file; they can also be overridden without editing the yaml file, as shown after the example below.
Index the query set features by:
```shell
python3 main/index.py [-cfg ${config_file}]
```
Arguments:
- `config_file`: Path of the configuration file in yaml format.
Examples:
```shell
python3 main/index.py -cfg configs/caltech.yaml
```
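Because the configuration is managed with yacs, the feature directories can be overridden programmatically. A minimal sketch (assuming the `index.query_fea_dir` and `index.gallery_fea_dir` keys used by `main/index.py`):
```python
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg

# a minimal sketch: override the feature directories from code instead of the yaml file
cfg = get_defaults_cfg()
cfg = setup_cfg(cfg, "configs/caltech.yaml", [
    "index.query_fea_dir", "/data/features/caltech/query/",
    "index.gallery_fea_dir", "/data/features/caltech/gallery/",
])
print(cfg.index.query_fea_dir)
```
The same key/value pairs can also be appended after the command-line options of the main scripts, since their `opts` argument is forwarded to `setup_cfg`.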
## Single Index
For visualizing results and analyzing failure cases, we provide a script for a single query image, so that you can easily visualize or save the retrieval results.
Use the following command for single-image indexing:
```shell
python3 main/single_index.py [-cfg ${config_file}]
```
Arguments:
- `config_file`: Path of the configuration file in yaml format.
Examples:
```shell
python3 main/single_index.py -cfg configs/caltech.yaml
```
Please see single_index.py for more details.
## Add Your Own Module
We basically categorize the retrieval process into 4 components.
- model: the pre-trained model for feature extraction.
- extract: assigns which layers to output, including splitter functions and aggregation methods.
- index: indexes features, including dimension processing, feature enhancement, distance metrics and re-ranking.
- evaluate: evaluates retrieval results, outputting recall and mAP.
Here we show how to add your own model to extract features.
1. Create your model file `retrieval_tool_box/models/backbone/backbone_impl/reid_baseline.py`.
```python
import torch.nn as nn

from ..backbone_base import BackboneBase
from ...registry import BACKBONES


@BACKBONES.register
class ft_net(BackboneBase):
    def __init__(self):
        pass

    def forward(self, x):
        pass
```
or
```python
import torch.nn as nn

from ..backbone_base import BackboneBase
from ...registry import BACKBONES


class FT_NET(BackboneBase):
    def __init__(self):
        pass

    def forward(self, x):
        pass


@BACKBONES.register
def ft_net():
    model = FT_NET()
    return model
```
2. Import the module in `retrieval_tool_box/models/backbone/__init__.py`.
```python
from .backbone_impl.reid_baseline import ft_net

__all__ = [
    'ft_net',
]
```
3. Use it in your config file.
```yaml
model:
  name: "ft_net"
  ft_net:
    load_checkpoint: "/data/my_model_zoo/res50_market1501.pth"
```
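After these three steps, the new backbone can be built exactly like the built-in ones. A minimal sketch (the config path below is hypothetical and should point to a yaml file containing the model section shown above):
```python
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
from retrieval_tool_box.models import build_model

# a minimal sketch: build the newly registered backbone from a config file
# "configs/my_reid.yaml" is a hypothetical path
cfg = get_defaults_cfg()
cfg = setup_cfg(cfg, "configs/my_reid.yaml", [])
model = build_model(cfg.model)   # resolves the backbone by the name given in the config
```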
## Pipeline Combinations Search
Since the tricks used in each stage have a significant impact on retrieval performance, we provide pipeline combination search scripts to help users find possible combinations of approaches with various hyper-parameters.
### Get into the combinations search scripts
```shell
cd search/
```
### Define Search Space
We decompose the search space into three sub search spaces: data_process, extract and index, each of which corresponds to a specific file. A search space is defined by adding methods with their hyper-parameters to the corresponding dict. You can add a search operator as follows:
```python
data_processes.add(
    "PadResize224",
    {
        "batch_size": 32,
        "folder": {
            "name": "Folder"
        },
        "collate_fn": {
            "name": "CollateFn"
        },
        "transformers": {
            "names": ["PadResize", "ToTensor", "Normalize"],
            "PadResize": {
                "size": 224,
                "padding_v": [124, 116, 104]
            },
            "Normalize": {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225]
            }
        }
    }
)
```
By doing this, a data_process operator named "PadResize224" is added to the data_process sub search space and will be searched in the following process.
### Search
Similar to the image retrieval pipeline, combinations search includes two stages: search for feature extraction and search for indexing.
#### search for feature extraction
Search for the feature extraction combinations by:
```shell
python3 search_extract.py [-sp ${save_path}] [-sm ${search_modules}]
```
Arguments:
- `save_path`: path for saving the output features in pickle format.
- `search_modules`: name of the folder containing search space files.
Examples:
```shell
python3 search_extract.py -sp /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules
```
#### search for indexing
Search for the indexing combinations by:
```shell
python3 search_query.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
```
Arguments:
- `fea_dir`: path of the output features extracted by the feature extraction combinations search.
- `search_modules`: name of the folder containing search space files.
- `save_path`: path for saving the retrieval results of each combination.
Examples:
```shell
python3 search_query.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
```

82
docs/INSTALL.md 100644
View File

@ -0,0 +1,82 @@
# Installation
## Requirements
- Linux (Windows is not officially supported)
- Python 3.5 or higher
- PyTorch 1.2.0 or higher
- torchvision 0.4.0 or higher
- CUDA 9.0 or higher
- numpy
- sklearn
- yacs
- tqdm
Our experiments are conducted on the following environment:
- PyTorch 1.2.0
- torchvision 0.4.0
- CUDA 9.0
- numpy 1.17.2
- sklearn 0.21.3
- tqdm 4.36.1
## Install PyRetri
1. Install PyTorch and torchvision following the official instructions.
2. Clone the PyRetri repository.
```shell
git clone https://github.com/???
cd ???
```
3. Install PyRetri.
```shell
python setup.py install
```
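If the installation succeeds, the package should be importable. A quick sanity check (the package name follows the source tree of this repository):
```python
# a quick sanity check after installation
import retrieval_tool_box
print(retrieval_tool_box.__version__)   # "0.1.0" at the time of writing
```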
## Prepare Datasets
### Datasets
In our experiments, we use four general image retrieval datasets and two person re-identification datasets.
- [Oxford5k](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/): collected by crawling Flickr images using the names of 11 different landmarks in Oxford, representing the landmark recognition task.
- [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html): containing photos of 200 bird species, representing the fine-grained visual categorization task.
- [Indoor](http://web.mit.edu/torralba/www/indoor.html): containing indoor scene images with 67 categories, representing the scene recognition task.
- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/): consisting of pictures of objects belonging to 101 categories, representing the general object recognition task.
- [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html): containing images taken on the Tsinghua campus under 6 camera viewpoints, representing the person re-identification task.
- [DukeMTMC-reID](https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view): containing images captured by 8 cameras, which is more challenging.
To reproduce our experimental results, you need to download these datasets first and then follow the steps below to re-organize them.
### Split Dataset
For the image retrieval task, the dataset should be divided into two subsets: a query set and a gallery set. If your dataset has already been divided, you can skip this step.
To help you reproduce our results conveniently, we provide several txt files, each of which is the division protocol used in our experiments. These txt files can be found in the `split_file/` folder, and you can use the following command to split the datasets mentioned above:
```shell
python3 main/split_dataset.py [-d ${dataset}] [-sf ${split_file}]
```
Arguments:
- `dataset`: Path of the dataset to be split.
- `split_file`: Path of the division protocol txt file, with each line corresponding to one image: `<image_path> <is_gallery_image>`. `<image_path>` is the relative path of the image, and a value of 1 or 0 for `<is_gallery_image>` denotes that the image belongs to the gallery or query set, respectively (a sketch for generating such a file is given at the end of this section).
Examples:
```shell
python3 main/split_dataset.py -d /data/caltech101/ -sf split_file/caltech_split.txt
```
Then the query folder and gallery folder will be created under the dataset folder.
Note:
1. For Re-ID datasets, the images are already divided, so there is no need to split them.
2. Since we split the dataset by creating symbolic links instead of copying images, overwriting is prohibited. In other words, if you want to split the dataset again, please remember to delete the previously generated folders.
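If you need a division protocol for your own dataset, the txt format described above is easy to generate. A minimal sketch (the policy of sending the first image of each class to the query set is only an example, not the protocol used in our experiments):
```python
import os

# a minimal sketch: write a split file in the "<image_path> <is_gallery_image>" format
# the split policy below (first image of each class goes to the query set) is only an example
dataset_root = "/data/caltech101/"
with open("my_split.txt", "w") as f:
    for cls in sorted(os.listdir(dataset_root)):
        cls_dir = os.path.join(dataset_root, cls)
        if not os.path.isdir(cls_dir):
            continue
        for i, img in enumerate(sorted(os.listdir(cls_dir))):
            is_gallery = 0 if i == 0 else 1
            f.write("{} {}\n".format(os.path.join(cls, img), is_gallery))
```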

58
docs/MODEL_ZOO.md 100644
View File

@ -0,0 +1,58 @@
# Model Zoo
## General image retrieval
### pre-trained models
| Training Set | Backbone | for Short | Download |
| :------------------: | :-------: | :-------: | :----------------------------------------------------------: |
| ImageNet | VGG-16 | I-VGG16 | [model](https://download.pytorch.org/models/vgg16-397923af.pth) |
| Places365 | VGG-16 | P-VGG16 | [model](https://drive.google.com/open?id=1U_VWbn_0L9mSDCBGiAIFbXxMvBeOiTG9) |
| ImageNet + Places365 | VGG-16 | H-VGG16 | [model](https://drive.google.com/open?id=11zE5kGNeeAXMhlHNv31Ye4kDcECrlJ1t) |
| ImageNet | ResNet-50 | I-Res50 | [model](https://download.pytorch.org/models/resnet50-19c8e357.pth) |
| Places365 | ResNet-50 | P-Res50 | [model](https://drive.google.com/open?id=1lp_nNw7hh1MQO_kBW86GG8y3_CyugdS2) |
| ImageNet + Places365 | ResNet-50 | H-Res50 | [model](https://drive.google.com/open?id=1_USt_gOxgV4NJ9Zjw_U8Fq-1HEC_H_ki) |
### performance
| Dataset | Data Augmentation | Backbone | Pooling | Dimension Process | mAP |
| :--------: | :------------------------: | :------: | :-----: | :------------------: | :--: |
| Oxford5k | ShorterResize + CenterCrop | H-VGG16 | GAP | l2 + SVD(whiten) + l2 | 62.9 |
| CUB-200 | ShorterResize + CenterCrop | I-Res50 | SCDA | l2 + PCA + l2 | 27.8 |
| Indoor | DirectResize | P-Res50 | CroW | l2 + PCA + l2 | 51.8 |
| Caltech101 | PadResize | I-Res50 | GeM | l2 + PCA + l2 | 77.9 |
Choosing the implementations mentioned above as baselines and adding some tricks, we have:
| Dataset | Implementations | mAP |
| :--------: | :--------------------------------: | :--: |
| Oxford5k | baseline + K-reciprocal | 72.9 |
| CUB-200 | baseline + K-reciprocal | 38.9 |
| Indoor | baseline + DBA + QE | 63.7 |
| Caltech101 | baseline + DBA + QE + K-reciprocal | 86.1 |
## Person re-identification
For person re-identification, we use the model provided by [Person_reID_baseline](https://github.com/layumi/Person_reID_baseline_pytorch) and reproduce its results. In addition, we train a model on DukeMTMC-reID with the same open-source code for further experiments.
### pre-trained models
| Training Set | Backbone | for Short | Download |
| :-----------: | :-------: | :-------: | :------: |
| Market-1501 | ResNet-50 | M-Res50 | [model](https://drive.google.com/open?id=1-6LT_NCgp_0ps3EO-uqERrtlGnbynWD5) |
| DukeMTMC-reID | ResNet-50 | D-Res50 | [model](https://drive.google.com/open?id=1X2Tiv-SQH3FxwClvBUalWkLqflgZHb9m) |
### performance
| Dataset | Data Augmentation | Backbone | Pooling | Dimension Process | mAP | Recall@1 |
| :-----------: | :--------------------: | :------: | :-----: | :---------------: | :--: | :------: |
| Market-1501 | DirectResize + TwoFlip | M-Res50 | GAP | l2 | 71.6 | 88.8 |
| DukeMTMC-reID | DirectResize + TwoFlip | D-Res50 | GAP | l2 | 62.5 | 80.4 |
Choosing the implementations mentioned above as baselines and adding some tricks, we have:
| Dataset | Implementations | mAP | Recall@1 |
| :-----------: | :-------------------------------------: | :--: | :------: |
| Market-1501 | Baseline + l2 + PCA + l2 + K-reciprocal | 84.8 | 90.4 |
| DukeMTMC-reID | Baseline + l2 + PCA + l2 + K-reciprocal | 78.3 | 84.2 |

BIN
main/.DS_Store vendored 100644

Binary file not shown.

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import argparse
import os
import torch
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
from retrieval_tool_box.datasets import build_folder, build_loader
from retrieval_tool_box.models import build_model
from retrieval_tool_box.extract import build_extract_helper
from torchvision import models
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
parser.add_argument('--data_json', '-dj', default=None, type=str, help='json file for dataset to be extracted')
parser.add_argument('--save_path', '-sp', default=None, type=str, help='save path for features')
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
parser.add_argument('--save_interval', '-si', default=5000, type=int, help='number of features saved in one part file')
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.data_json is not None, 'the dataset json must be provided!'
assert args.save_path is not None, 'the save path must be provided!'
assert args.config_file is not None, 'a config file must be provided!'
# init and load retrieval pipeline settings
cfg = get_defaults_cfg()
cfg = setup_cfg(cfg, args.config_file, args.opts)
# build dataset and dataloader
dataset = build_folder(args.data_json, cfg.datasets)
dataloader = build_loader(dataset, cfg.datasets)
# build model
model = build_model(cfg.model)
# build helper and extract features
extract_helper = build_extract_helper(model, cfg.extract)
extract_helper.do_extract(dataloader, args.save_path, args.save_interval)
if __name__ == '__main__':
main()

48
main/index.py 100644
View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
import argparse
import os
import pickle
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
from retrieval_tool_box.index import build_index_helper, feature_loader
from retrieval_tool_box.evaluate import build_evaluate_helper
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.config_file is not None, 'a config file must be provided!'
# init and load retrieval pipeline settings
cfg = get_defaults_cfg()
cfg = setup_cfg(cfg, args.config_file, args.opts)
# load features
query_fea, query_info, _ = feature_loader.load(cfg.index.query_fea_dir, cfg.index.feature_names)
gallery_fea, gallery_info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)
# build helper and index features
index_helper = build_index_helper(cfg.index)
index_result_info, query_fea, gallery_fea = index_helper.do_index(query_fea, query_info, gallery_fea)
# build helper and evaluate results
evaluate_helper = build_evaluate_helper(cfg.evaluate)
mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)
# show results
evaluate_helper.show_results(mAP, recall_at_k)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
import argparse
from retrieval_tool_box.extract import make_data_json
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
parser.add_argument('--dataset', '-d', default=None, type=str, help="path for the dataset that make the json file")
parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for the json file")
parser.add_argument('--type', '-t', default=None, type=str, help="mode of the dataset")
parser.add_argument('--ground_truth', '-gt', default=None, type=str, help="ground truth of the dataset")
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.dataset is not None, 'the data must be provided!'
assert args.save_path is not None, 'the save path must be provided!'
assert args.type is not None, 'the type must be provided!'
# make data json
make_data_json(args.dataset, args.save_path, args.type, args.ground_truth)
print('make data json done!')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
import argparse
import os
from PIL import Image
import numpy as np
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
from retrieval_tool_box.datasets import build_transformers
from retrieval_tool_box.models import build_model
from retrieval_tool_box.extract import build_extract_helper
from retrieval_tool_box.index import build_index_helper, feature_loader
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.config_file is not None, 'a config file must be provided!'
assert os.path.exists(args.config_file), 'the config file must exist!'
# init and load retrieval pipeline settings
cfg = get_defaults_cfg()
cfg = setup_cfg(cfg, args.config_file, args.opts)
# set path for single image
path = '/data/caltech101/query/airplanes/image_0004.jpg'
# build transformers
transformers = build_transformers(cfg.datasets.transformers)
# build model
model = build_model(cfg.model)
# read image and convert it to tensor
img = Image.open(path).convert("RGB")
img_tensor = transformers(img)
# build helper and extract feature for single image
extract_helper = build_extract_helper(model, cfg.extract)
img_fea_info = extract_helper.do_single_extract(img_tensor)
stacked_feature = list()
for name in cfg.index.feature_names:
assert name in img_fea_info[0], "invalid feature name: {} not in {}!".format(name, img_fea_info[0].keys())
stacked_feature.append(img_fea_info[0][name])
img_fea = np.concatenate(stacked_feature, axis=1)
# load gallery features
gallery_fea, gallery_info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)
# build helper and single index feature
index_helper = build_index_helper(cfg.index)
index_result_info, query_fea, gallery_fea = index_helper.do_index(img_fea, img_fea_info, gallery_fea)
# index_helper.show_topk_retrieved_images(index_result_info[0], 4, gallery_info)
index_helper.save_topk_retrieved_images('../retrieved_images', index_result_info[0], 5, gallery_info)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
import argparse
import os
from retrieval_tool_box.extract.utils import split_dataset
def parse_args():
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
parser.add_argument('--dataset', '-d', default=None, type=str, help="path for the dataset.")
parser.add_argument('--split_file', '-sf', default=None, type=str, help="path of the split file for the dataset.")
args = parser.parse_args()
return args
def main():
# init args
args = parse_args()
assert args.dataset is not None, 'the dataset must be provided!'
assert args.split_file is not None, 'the split file must be provided!'
# split dataset
split_dataset(args.dataset, args.split_file)
print('split dataset done!')
if __name__ == '__main__':
main()

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

7
requirements.txt 100644
View File

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
numpy
torch>=1.1
torchvision>=0.4
sklearn
yacs
tqdm

BIN
retrieval_tool_box/.DS_Store vendored 100644

Binary file not shown.

View File

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
__version__ = "0.1.0"

View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .config import setup_cfg, get_defaults_cfg
__all__ = [
'get_defaults_cfg',
'setup_cfg',
]

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from ..datasets import get_datasets_cfg
from ..models import get_model_cfg
from ..extract import get_extract_cfg
from ..index import get_index_cfg
from ..evaluate import get_evaluate_cfg
def get_defaults_cfg() -> CfgNode:
"""
Construct the default configuration tree.
Returns:
cfg (CfgNode): the default configuration tree.
"""
cfg = CfgNode()
cfg["datasets"] = get_datasets_cfg()
cfg["model"] = get_model_cfg()
cfg["extract"] = get_extract_cfg()
cfg["index"] = get_index_cfg()
cfg["evaluate"] = get_evaluate_cfg()
return cfg
def setup_cfg(cfg: CfgNode, cfg_file: str, cfg_opts: list or None = None) -> CfgNode:
"""
Load a yaml config file and merge it this CfgNode.
Args:
cfg (CfgNode): the configuration tree with default structure.
cfg_file (str): the path for yaml config file which is matched with the CfgNode.
cfg_opts (list, optional): config (keys, values) in a list (e.g., from command line) into this CfgNode.
Returns:
cfg (CfgNode): the configuration tree with settings in the config file.
"""
cfg.merge_from_file(cfg_file)
cfg.merge_from_list(cfg_opts)
cfg.freeze()
return cfg

Binary file not shown.

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .builder import build_collate, build_folder, build_transformers, build_loader
from .config import get_datasets_cfg
__all__ = [
'get_datasets_cfg',
'build_collate', 'build_folder', 'build_transformers', 'build_loader',
]

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
from .collate_fn import CollateFnBase
from .folder import FolderBase
from .transformer import TransformerBase
from ..utils import simple_build
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
def build_collate(cfg: CfgNode) -> CollateFnBase:
"""
Instantiate a collate class with the given configuration tree.
Args:
cfg (CfgNode): the configuration tree.
Returns:
collate (CollateFnBase): a collate class.
"""
name = cfg["name"]
collate = simple_build(name, cfg, COLLATEFNS)
return collate
def build_transformers(cfg: CfgNode) -> Compose:
"""
Instantiate a compose class containing several transforms with the given configuration tree.
Args:
cfg (CfgNode): the configuration tree.
Returns:
transformers (Compose): a compose class.
"""
names = cfg["names"]
transformers = list()
for name in names:
transformers.append(simple_build(name, cfg, TRANSFORMERS))
transformers = Compose(transformers)
return transformers
def build_folder(data_json_path: str, cfg: CfgNode) -> FolderBase:
"""
Instantiate a folder class with the given configuration tree.
Args:
data_json_path (str): the path of the data json file.
cfg (CfgNode): the configuration tree.
Returns:
folder (FolderBase): a folder class.
"""
trans = build_transformers(cfg.transformers)
folder = simple_build(cfg.folder["name"], cfg.folder, FOLDERS, data_json_path=data_json_path, transformer=trans)
return folder
def build_loader(folder: FolderBase, cfg: CfgNode) -> DataLoader:
"""
Instantiate a data loader class with the given configuration tree.
Args:
folder (FolderBase): the folder function.
cfg (CfgNode): the configuration tree.
Returns:
data_loader (DataLoader): a data loader class.
"""
co_fn = build_collate(cfg.collate_fn)
data_loader = DataLoader(folder, cfg["batch_size"], collate_fn=co_fn, num_workers=8, pin_memory=True)
return data_loader

View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .collate_fn_impl.collate_fn import CollateFn
from .collate_fn_base import CollateFnBase
__all__ = [
'CollateFnBase',
'CollateFn',
]

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
import torch
from ...utils import ModuleBase
from typing import Dict, List
class CollateFnBase(ModuleBase):
"""
The base class of collate function.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps: default hyper parameters in a dict (keys, values).
"""
super(CollateFnBase, self).__init__(hps)
@abstractmethod
def __call__(self, batch: List[Dict]) -> Dict[str, torch.tensor]:
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
import torch
from ..collate_fn_base import CollateFnBase
from ...registry import COLLATEFNS
from torch.utils.data.dataloader import default_collate
from typing import Dict, List
@COLLATEFNS.register
class CollateFn(CollateFnBase):
"""
A wrapper for torch.utils.data.dataloader.default_collate.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(CollateFn, self).__init__(hps)
def __call__(self, batch: List[Dict]) -> Dict[str, torch.tensor]:
assert isinstance(batch, list)
assert isinstance(batch[0], dict)
return default_collate(batch)

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
from ..utils import get_config_from_registry
def get_collate_cfg() -> CfgNode:
cfg = get_config_from_registry(COLLATEFNS)
cfg["name"] = "unknown"
return cfg
def get_folder_cfg() -> CfgNode:
cfg = get_config_from_registry(FOLDERS)
cfg["name"] = "unknown"
return cfg
def get_tranformers_cfg() -> CfgNode:
cfg = get_config_from_registry(TRANSFORMERS)
cfg["names"] = ["unknown"]
return cfg
def get_datasets_cfg() -> CfgNode:
cfg = CfgNode()
cfg["collate_fn"] = get_collate_cfg()
cfg["folder"] = get_folder_cfg()
cfg["transformers"] = get_tranformers_cfg()
cfg["batch_size"] = 1
return cfg

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from .folder_impl.folder import Folder
from .folder_base import FolderBase
__all__ = [
'FolderBase',
'Folder',
]

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
import numpy as np
from PIL import Image
import pickle
import os
from abc import abstractmethod
from ...utils import ModuleBase
from typing import Dict, List
class FolderBase(ModuleBase):
"""
The base class of folder function.
"""
default_hyper_params = dict()
def __init__(self, data_json_path: str, transformer: callable or None = None, hps: Dict or None = None):
"""
Args:
data_json_path (str): the path for data json file.
transformer (callable): a list of data augmentation operations.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(FolderBase, self).__init__(hps)
with open(data_json_path, "rb") as f:
self.data_info = pickle.load(f)
self.data_json_path = data_json_path
self.transformer = transformer
def __len__(self) -> int:
pass
@abstractmethod
def __getitem__(self, idx: int) -> Dict:
pass
def find_classes(self, info_dicts: Dict) -> (List, Dict):
pass
def read_img(self, path: str) -> Image:
"""
Load image.
Args:
path (str): the path of the image.
Returns:
image (Image): shape (H, W, C).
"""
try:
img = Image.open(path)
img = img.convert("RGB")
return img
except Exception as e:
print('[DataSet]: WARNING image can not be loaded: {}'.format(str(e)))
return None

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
import pickle
from ..folder_base import FolderBase
from ...registry import FOLDERS
from typing import Dict, List
@FOLDERS.register
class Folder(FolderBase):
"""
A folder function for loading images.
Hyper-Params:
use_bbox: bool, whether use bbox to crop image. When set to true,
make sure that bbox attribute is provided in your data json and bbox format is [x1, y1, x2, y2].
"""
default_hyper_params = {
"use_bbox": False,
}
def __init__(self, data_json_path: str, transformer: callable or None = None, hps: Dict or None = None):
"""
Args:
data_json_path (str): the path for data json file.
transformer (callable): a list of data augmentation operations.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(Folder, self).__init__(data_json_path, transformer, hps)
self.classes, self.class_to_idx = self.find_classes(self.data_info["info_dicts"])
def find_classes(self, info_dicts: Dict) -> (List, Dict):
"""
Get the class names and the mapping relations.
Args:
info_dicts (dict): the dataset information contained the data json file.
Returns:
tuple (list, dict): a list of class names and a dict for projecting class name into int label.
"""
classes = list()
for i in range(len(info_dicts)):
if info_dicts[i]["label"] not in classes:
classes.append(info_dicts[i]["label"])
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __len__(self) -> int:
"""
Get the number of total training samples.
Returns:
length (int): the number of total training samples.
"""
return len(self.data_info["info_dicts"])
def __getitem__(self, idx: int) -> Dict:
"""
Load the image and convert it to tensor for training.
Args:
idx (int): the serial number of the image.
Returns:
item (dict): the dict containing the image after augmentations, serial number and label.
"""
info = self.data_info["info_dicts"][idx]
img = self.read_img(info["path"])
if self._hyper_params["use_bbox"]:
assert info["bbox"] is not None, 'image {} does not have a bbox'.format(info["path"])
x1, y1, x2, y2 = info["bbox"]
box = map(int, (x1, y1, x2, y2))
img = img.crop(box)
img = self.transformer(img)
return {"img": img, "idx": idx, "label": self.class_to_idx[info["label"]]}

View File

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from ..utils.registry import Registry
COLLATEFNS = Registry()
FOLDERS = Registry()
TRANSFORMERS = Registry()

Binary file not shown.

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .transformers_impl.transformers import (DirectResize, PadResize, ShorterResize, CenterCrop,
ToTensor, ToCaffeTensor, Normalize, TenCrop, TwoFlip)
from .transformers_base import TransformerBase
__all__ = [
'TransformerBase',
'DirectResize', 'PadResize', 'ShorterResize', 'CenterCrop', 'ToTensor', 'ToCaffeTensor',
'Normalize', 'TenCrop', 'TwoFlip',
]

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
from PIL import Image
import torch
from ...utils import ModuleBase
from ...utils import Registry
from typing import Dict
class TransformerBase(ModuleBase):
"""
The base class of data augmentation operations.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(TransformerBase, self).__init__(hps)
@abstractmethod
def __call__(self, img: Image) -> Image or torch.tensor:
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
from ..transformers_base import TransformerBase
from ...registry import TRANSFORMERS
from torchvision.transforms import Resize as TResize
from torchvision.transforms import TenCrop as TTenCrop
from torchvision.transforms import CenterCrop as TCenterCrop
from torchvision.transforms import ToTensor as TToTensor
from torchvision.transforms.functional import hflip
from typing import Dict
@TRANSFORMERS.register
class DirectResize(TransformerBase):
"""
Directly resize image to target size, regardless of h: w ratio.
Hyper-Params
size (sequence): desired output size.
interpolation (int): desired interpolation.
"""
default_hyper_params = {
"size": (224, 224),
"interpolation": Image.BILINEAR,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(DirectResize, self).__init__(hps)
self.t_transformer = TResize(self._hyper_params["size"], self._hyper_params["interpolation"])
def __call__(self, img: Image) -> Image:
return self.t_transformer(img)
@TRANSFORMERS.register
class PadResize(TransformerBase):
"""
Resize image's longer edge to target size, and then pad the shorter edge to target size.
Hyper-Params
size (int): desired output size of the longer edge.
padding_v (sequence): padding pixel value.
interpolation (int): desired interpolation.
"""
default_hyper_params = {
"size": 224,
"padding_v": [124, 116, 104],
"interpolation": Image.BILINEAR,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps: default hyper parameters in a dict (keys, values).
"""
super(PadResize, self).__init__(hps)
def __call__(self, img: Image) -> Image:
target_size = self._hyper_params["size"]
padding_v = tuple(self._hyper_params["padding_v"])
interpolation = self._hyper_params["interpolation"]
w, h = img.size
if w > h:
img = img.resize((int(target_size), int(h * target_size * 1.0 / w)), interpolation)
else:
img = img.resize((int(w * target_size * 1.0 / h), int(target_size)), interpolation)
ret_img = Image.new("RGB", (target_size, target_size), padding_v)
w, h = img.size
st_w = int((ret_img.size[0] - w) / 2.0)
st_h = int((ret_img.size[1] - h) / 2.0)
ret_img.paste(img, (st_w, st_h))
return ret_img
@TRANSFORMERS.register
class ShorterResize(TransformerBase):
"""
Resize image's shorter edge to target size, while keep h: w ratio.
Hyper-Params
size (int): desired output size.
interpolation (int): desired interpolation.
"""
default_hyper_params = {
"size": 224,
"interpolation": Image.BILINEAR,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(ShorterResize, self).__init__(hps)
self.t_transformer = TResize(self._hyper_params["size"], self._hyper_params["interpolation"])
def __call__(self, img: Image) -> Image:
return self.t_transformer(img)
@TRANSFORMERS.register
class CenterCrop(TransformerBase):
"""
A wrapper of CenterCrop in PyTorch; see torchvision.transforms.CenterCrop for explanation.
Hyper-Params
size(sequence or int): desired output size.
"""
default_hyper_params = {
"size": 224,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(CenterCrop, self).__init__(hps)
self.t_transformer = TCenterCrop(self._hyper_params["size"])
def __call__(self, img: Image) -> Image:
return self.t_transformer(img)
@TRANSFORMERS.register
class ToTensor(TransformerBase):
"""
A wrapper of ToTensor in PyTorch; see torchvision.transforms.ToTensor for explanation.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(ToTensor, self).__init__(hps)
self.t_transformer = TToTensor()
def __call__(self, imgs: Image or tuple) -> torch.Tensor:
if not isinstance(imgs, tuple):
imgs = [imgs]
ret_tensor = list()
for img in imgs:
ret_tensor.append(self.t_transformer(img))
ret_tensor = torch.stack(ret_tensor, dim=0)
return ret_tensor
@TRANSFORMERS.register
class ToCaffeTensor(TransformerBase):
"""
Create tensors for models trained in caffe.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(ToCaffeTensor, self).__init__(hps)
def __call__(self, imgs: Image or tuple) -> torch.tensor:
if not isinstance(imgs, tuple):
imgs = [imgs]
ret_tensor = list()
for img in imgs:
img = np.array(img, np.int32, copy=False)
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
img = np.stack([b, g, r], axis=2)
img = torch.from_numpy(img)
img = img.transpose(0, 1).transpose(0, 2).contiguous()
img = img.float()
ret_tensor.append(img)
ret_tensor = torch.stack(ret_tensor, dim=0)
return ret_tensor
@TRANSFORMERS.register
class Normalize(TransformerBase):
"""
Normalize a tensor image with mean and standard deviation.
Hyper-Params
mean (sequence): sequence of means for each channel.
std (sequence): sequence of standard deviations for each channel.
"""
default_hyper_params = {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(Normalize, self).__init__(hps)
for v in ["mean", "std"]:
self.__dict__[v] = np.array(self._hyper_params[v])[None, :, None, None]
self.__dict__[v] = torch.from_numpy(self.__dict__[v]).float()
def __call__(self, tensor: torch.tensor) -> torch.tensor:
assert tensor.ndimension() == 4
tensor.sub_(self.mean).div_(self.std)
return tensor
@TRANSFORMERS.register
class TenCrop(TransformerBase):
"""
A wrapper of TenCrop in PyTorch; see torchvision.transforms.TenCrop for explanation.
Hyper-Params
size (sequence or int): desired output size.
"""
default_hyper_params = {
"size": 224,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(TenCrop, self).__init__(hps)
self.t_transformer = TTenCrop(self._hyper_params["size"])
def __call__(self, img: Image) -> Image:
return self.t_transformer(img)
@TRANSFORMERS.register
class TwoFlip(TransformerBase):
"""
Return the image itself and its horizontal flipped one.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(TwoFlip, self).__init__(hps)
def __call__(self, img: Image) -> (Image, Image):
return img, hflip(img)

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .config import get_evaluate_cfg
from .builder import build_evaluate_helper
__all__ = [
'get_evaluate_cfg',
'build_evaluate_helper',
]

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import EVALUATORS
from .evaluator import EvaluatorBase
from .helper import EvaluateHelper
from ..utils import simple_build
def build_evaluator(cfg: CfgNode) -> EvaluatorBase:
"""
Instantiate a evaluator class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
evaluator (EvaluatorBase): a evaluator class.
"""
name = cfg["name"]
evaluator = simple_build(name, cfg, EVALUATORS)
return evaluator
def build_evaluate_helper(cfg: CfgNode) -> EvaluateHelper:
"""
Instantiate a evaluate helper class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
helper (EvaluateHelper): a evaluate helper class.
"""
evaluator = build_evaluator(cfg.evaluator)
helper = EvaluateHelper(evaluator)
return helper

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import EVALUATORS
from ..utils import get_config_from_registry
def get_evaluator_cfg() -> CfgNode:
cfg = get_config_from_registry(EVALUATORS)
cfg["name"] = "unknown"
return cfg
def get_evaluate_cfg() -> CfgNode:
cfg = CfgNode()
cfg["evaluator"] = get_evaluator_cfg()
return cfg

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from .evaluators_impl.overall import OverAll
from .evaluators_impl.oxford_overall import OxfordOverAll
from .evaluators_impl.reid_overall import ReIDOverAll
from .evaluators_base import EvaluatorBase
__all__ = [
'EvaluatorBase',
'OverAll', 'OxfordOverAll',
'ReIDOverAll',
]

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
from ...utils import ModuleBase
from typing import Dict
class EvaluatorBase(ModuleBase):
"""
The base class of evaluators which compute mAP and recall.
"""
default_hyper_params = {}
def __init__(self, hps: Dict or None = None):
super(EvaluatorBase, self).__init__(hps)
@abstractmethod
def __call__(self, query_result: Dict, gallery_info: Dict) -> (float, Dict):
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..evaluators_base import EvaluatorBase
from ...registry import EVALUATORS
from sklearn.metrics import average_precision_score
from typing import Dict, List
@EVALUATORS.register
class OverAll(EvaluatorBase):
"""
A evaluator for mAP and recall computation.
Hyper-Params
recall_k (sequence): positions of recalls to be calculated.
"""
default_hyper_params = {
"recall_k": [1, 2, 4, 8],
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(OverAll, self).__init__(hps)
self._hyper_params["recall_k"] = np.sort(self._hyper_params["recall_k"])
def compute_recall_at_k(self, gt: List[bool], result_dict: Dict) -> None:
"""
Calculate the recall at each position.
Args:
gt (sequence): a list of bool indicating if the result is equal to the label.
result_dict (dict): a dict of indexing results.
"""
ks = self._hyper_params["recall_k"]
gt = gt[:ks[-1]]
first_tp = np.where(gt)[0]
if len(first_tp) == 0:
return
for k in ks:
if k >= first_tp[0] + 1:
result_dict[k] = result_dict[k] + 1
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
"""
Calculate the mAP and recall for the indexing results.
Args:
query_result (list): a list of indexing results.
gallery_info (list): a list of gallery set information.
Returns:
tuple (float, dict): mean average precision and recall for each position.
"""
aps = list()
# For mAP calculation
pseudo_score = np.arange(0, len(gallery_info))[::-1]
recall_at_k = dict()
for k in self._hyper_params["recall_k"]:
recall_at_k[k] = 0
gallery_label = np.array([gallery_info[idx]["label_idx"] for idx in range(len(gallery_info))])
for i in range(len(query_result)):
ranked_idx = query_result[i]["ranked_neighbors_idx"]
gt = (gallery_label[query_result[i]["ranked_neighbors_idx"]] == query_result[i]["label_idx"])
aps.append(average_precision_score(gt, pseudo_score[:len(gt)]))
# deal with 'gallery as query' test
if gallery_info[ranked_idx[0]]["path"] == query_result[i]["path"]:
gt = gt[1:]  # numpy arrays have no pop(); drop the self-match at rank 0
self.compute_recall_at_k(gt, recall_at_k)
mAP = np.mean(aps) * 100
for k in recall_at_k:
recall_at_k[k] = recall_at_k[k] * 100 / len(query_result)
return mAP, recall_at_k

View File

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
import os
import numpy as np
from ..evaluators_base import EvaluatorBase
from ...registry import EVALUATORS
from typing import Dict, List
@EVALUATORS.register
class OxfordOverAll(EvaluatorBase):
"""
A evaluator for Oxford mAP and recall computation.
Hyper-Params
gt_dir (str): the path of the oxford ground truth.
recall_k (sequence): positions of recalls to be calculated.
"""
default_hyper_params = {
"gt_dir": "/data/cbir/oxford/gt",
"recall_k": [1, 2, 4, 8],
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(OxfordOverAll, self).__init__(hps)
assert os.path.exists(self._hyper_params["gt_dir"]), 'the ground truth files must exist!'
@staticmethod
def _load_tag_set(file: str) -> set:
"""
Read information from the txt file.
Args:
file (str): the path of the txt file.
Returns:
ret (set): the information.
"""
ret = set()
with open(file, "r") as f:
for line in f.readlines():
ret.add(line.strip())
return ret
def compute_ap(self, query_tag: str, ranked_tags: List[str], recall_at_k: Dict) -> float:
"""
Calculate the ap for one query.
Args:
query_tag (str): name of the query image.
ranked_tags (list): a list of label of the indexing results.
recall_at_k (dict): positions of recalls to be calculated.
Returns:
ap (float): ap for one query.
"""
gt_prefix = os.path.join(self._hyper_params["gt_dir"], query_tag)
good_set = self._load_tag_set(gt_prefix + "_good.txt")
ok_set = self._load_tag_set(gt_prefix + "_ok.txt")
junk_set = self._load_tag_set(gt_prefix + "_junk.txt")
pos_set = set.union(good_set, ok_set)
old_recall = 0.0
old_precision = 1.0
ap = 0.0
intersect_size = 0.0
i = 0
first_tp = -1
for tag in ranked_tags:
if tag in junk_set:
continue
if tag in pos_set:
intersect_size += 1
# Remember that in oxford query mode, the first element in rank_list is the query itself.
if first_tp == -1:
first_tp = i
recall = intersect_size * 1.0 / len(pos_set)
precision = intersect_size / (i + 1.0)
ap += (recall - old_recall) * ((old_precision + precision) / 2.0)
old_recall = recall
old_precision = precision
i += 1
if first_tp != -1:
ks = self._hyper_params["recall_k"]
for k in ks:
if k >= first_tp + 1:
recall_at_k[k] = recall_at_k[k] + 1
return ap
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
"""
Calculate the mAP and recall for the indexing results.
Args:
query_result (list): a list of indexing results.
gallery_info (list): a list of gallery set information.
Returns:
tuple (float, dict): mean average precision and recall for each position.
"""
aps = list()
recall_at_k = dict()
for k in self._hyper_params["recall_k"]:
recall_at_k[k] = 0
for i in range(len(query_result)):
ranked_idx = query_result[i]["ranked_neighbors_idx"]
ranked_tags = list()
for idx in ranked_idx:
ranked_tags.append(gallery_info[idx]["label"])
aps.append(self.compute_ap(query_result[i]["query_name"], ranked_tags, recall_at_k))
mAP = np.mean(aps) * 100
for k in recall_at_k:
recall_at_k[k] = recall_at_k[k] * 100 / len(query_result)
return mAP, recall_at_k

View File

@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
from ..evaluators_base import EvaluatorBase
from ...registry import EVALUATORS
from sklearn.metrics import average_precision_score
from typing import Dict, List
@EVALUATORS.register
class ReIDOverAll(EvaluatorBase):
"""
A evaluator for Re-ID task mAP and recall computation.
Hyper-Params
recall_k (sequence): positions of recalls to be calculated.
"""
default_hyper_params = {
"recall_k": [1, 2, 4, 8],
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(ReIDOverAll, self).__init__(hps)
self._hyper_params["recall_k"] = np.sort(self._hyper_params["recall_k"])
def compute_ap_cmc(self, index: np.ndarray, good_index: np.ndarray, junk_index: np.ndarray) -> (float, torch.tensor):
"""
Calculate the ap and cmc for one query.
Args:
index (np.ndarray): the sorted retrieval index for one query.
good_index (np.ndarray): the index for good matching.
junk_index (np.ndarray): the index for junk matching.
Returns:
tupele (float, torch.tensor): (ap, cmc), ap and cmc for one query.
"""
ap = 0
cmc = torch.IntTensor(len(index)).zero_()
if good_index.size == 0:
cmc[0] = -1
return ap, cmc
# remove junk_index
mask = np.in1d(index, junk_index, invert=True)
index = index[mask]
# find good_index index
ngood = len(good_index)
mask = np.in1d(index, good_index)
rows_good = np.argwhere(mask == True)
rows_good = rows_good.flatten()
cmc[rows_good[0]:] = 1
for i in range(ngood):
d_recall = 1.0 / ngood
precision = (i + 1) * 1.0 / (rows_good[i] + 1)
if rows_good[i] != 0:
old_precision = i * 1.0 / rows_good[i]
else:
old_precision = 1.0
ap = ap + d_recall * (old_precision + precision) / 2
return ap, cmc
def evaluate_once(self, index: np.ndarray, ql: int, qc: int, gl: np.ndarray, gc: np.ndarray) -> (float, torch.tensor):
"""
Generate the indexes and calculate the ap and cmc for one query.
Args:
index (np.ndarray): the sorted retrieval index for one query.
ql (int): the person id of the query.
qc (int): the camera id of the query.
gl (np.ndarray): the person ids of the gallery set.
gc (np.ndarray): the camera ids of the gallery set.
Returns:
tuple (float, torch.tensor): ap and cmc for one query.
"""
query_index = (ql == gl)
query_index = np.argwhere(query_index)
camera_index = (qc == gc)
camera_index = np.argwhere(camera_index)
# good index
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
# junk index
junk_index1 = np.argwhere(gl == -1)
junk_index2 = np.intersect1d(query_index, camera_index)
junk_index = np.append(junk_index2, junk_index1)
AP_tmp, CMC_tmp = self.compute_ap_cmc(index, good_index, junk_index)
return AP_tmp, CMC_tmp
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
"""
Calculate the mAP and recall for the indexing results.
Args:
query_result (list): a list of indexing results.
gallery_info (list): a list of gallery set information.
Returns:
tuple (float, dict): mean average precision and recall for each position.
"""
AP = 0.0
CMC = torch.IntTensor(range(len(gallery_info))).zero_()
gallery_label = np.array([int(gallery_info[idx]["label"]) for idx in range(len(gallery_info))])
gallery_cam = np.array([int(gallery_info[idx]["cam"]) for idx in range(len(gallery_info))])
recall_at_k = dict()
for k in self._hyper_params["recall_k"]:
recall_at_k[k] = 0
for i in range(len(query_result)):
AP_tmp, CMC_tmp = self.evaluate_once(np.array(query_result[i]["ranked_neighbors_idx"]),
int(query_result[i]["label"]), int(query_result[i]["cam"]),
gallery_label, gallery_cam)
if CMC_tmp[0] == -1:
continue
CMC = CMC + CMC_tmp
AP += AP_tmp
CMC = CMC.float()
CMC = CMC / len(query_result) # average CMC
for k in recall_at_k:
recall_at_k[k] = (CMC[k-1] * 100).item()
mAP = AP / len(query_result) * 100
return mAP, recall_at_k

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .helper import EvaluateHelper
__all__ = [
'EvaluateHelper',
]

View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import torch
from ..evaluator import EvaluatorBase
from typing import Dict, List
class EvaluateHelper:
"""
A helper class to evaluate query results.
"""
def __init__(self, evaluator: EvaluatorBase):
"""
Args:
evaluator: a evaluator class.
"""
self.evaluator = evaluator
self.recall_k = evaluator.default_hyper_params["recall_k"]
def show_results(self, mAP: float, recall_at_k: Dict) -> None:
"""
Show the evaluate results.
Args:
mAP (float): mean average precision.
recall_at_k (Dict): recall at the k position.
"""
repr_str = "mAP: {:.1f}\n".format(mAP)
for k in self.recall_k:
repr_str += "R@{}: {:.1f}\t".format(k, recall_at_k[k])
print('--------------- Retrieval Evaluation ------------')
print(repr_str)
def do_eval(self, query_result_info: List, gallery_info: List) -> (float, Dict):
"""
Get the evaluate results.
Args:
query_result_info (list): a list of indexing results.
gallery_info (list): a list of gallery set information.
Returns:
tuple (float, Dict): mean average precision and recall for each position.
"""
mAP, recall_at_k = self.evaluator(query_result_info, gallery_info)
return mAP, recall_at_k

View File

@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
from ..utils import Registry
EVALUATORS = Registry()

Binary file not shown.

View File

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .config import get_extractor_cfg, get_aggregators_cfg, get_extract_cfg
from .builder import build_aggregators, build_extractor, build_extract_helper
from .utils import split_dataset, make_data_json
__all__ = [
'get_extract_cfg',
'build_aggregators', 'build_extractor', 'build_extract_helper',
'split_dataset', 'make_data_json',
]

Binary file not shown.

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .aggregators_impl.crow import Crow
from .aggregators_impl.gap import GAP
from .aggregators_impl.gem import GeM
from .aggregators_impl.gmp import GMP
from .aggregators_impl.pwa import PWA
from .aggregators_impl.r_mac import RMAC
from .aggregators_impl.scda import SCDA
from .aggregators_impl.spoc import SPoC
from .aggregators_base import AggregatorBase
__all__ = [
'AggregatorBase',
'Crow', 'GAP', 'GeM', 'GMP', 'PWA', 'RMAC', 'SCDA', 'SPoC',
'build_aggregators',
]

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
import torch
from ...utils import ModuleBase
from typing import Dict
class AggregatorBase(ModuleBase):
r"""
The base class for feature aggregators.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(AggregatorBase, self).__init__(hps)
@abstractmethod
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class Crow(AggregatorBase):
"""
Cross-dimensional Weighting for Aggregated Deep Convolutional Features.
c.f. https://arxiv.org/pdf/1512.04065.pdf
Hyper-Params
spatial_a (float): hyper-parameter for calculating spatial weight.
spatial_b (float): hyper-parameter for calculating spatial weight.
"""
default_hyper_params = {
"spatial_a": 2.0,
"spatial_b": 2.0,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
self.first_show = True
super(Crow, self).__init__(hps)
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
spatial_a = self._hyper_params["spatial_a"]
spatial_b = self._hyper_params["spatial_b"]
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
spatial_weight = fea.sum(dim=1, keepdims=True)
z = (spatial_weight ** spatial_a).sum(dim=(2, 3), keepdims=True)
z = z ** (1.0 / spatial_a)
spatial_weight = (spatial_weight / z) ** (1.0 / spatial_b)
c, h, w = fea.shape[1:]
nonzeros = (fea!=0).float().sum(dim=(2, 3)) / 1.0 / (w * h) + 1e-6
channel_weight = torch.log(nonzeros.sum(dim=1, keepdims=True) / nonzeros)
fea = fea * spatial_weight
fea = fea.sum(dim=(2, 3))
fea = fea * channel_weight
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[Crow Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret
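A minimal usage sketch for the aggregator above; the feature name and tensor sizes are illustrative assumptions, and Crow refers to the class defined in this commit.

```python
# Hedged usage sketch: CroW on a dummy (N, C, H, W) feature map.
import torch

crow = Crow()                                # default spatial_a = spatial_b = 2.0
out = crow({"pool5": torch.randn(2, 512, 7, 7)})
print(out["pool5_Crow"].shape)               # torch.Size([2, 512])
```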

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class GAP(AggregatorBase):
"""
Global average pooling.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
self.first_show = True
super(GAP, self).__init__(hps)
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
fea = fea.mean(dim=3).mean(dim=2)
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[GAP Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret

View File

@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class GeM(AggregatorBase):
"""
Generalized-mean pooling.
c.f. https://pdfs.semanticscholar.org/a2ca/e0ed91d8a3298b3209fc7ea0a4248b914386.pdf
Hyper-Params
p (float): hyper-parameter for calculating generalized mean. If p = 1, GeM is equal to global average pooling, and
if p = +infinity, GeM is equal to global max pooling.
"""
default_hyper_params = {
"p": 3.0,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
self.first_show = True
super(GeM, self).__init__(hps)
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
p = self._hyper_params["p"]
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
fea = fea ** p
h, w = fea.shape[2:]
fea = fea.sum(dim=(2, 3)) * 1.0 / w / h
fea = fea ** (1.0 / p)
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[GeM Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret
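The docstring's limit cases can be checked numerically. The sketch below assumes non-negative (post-ReLU) activations and assumes the hps dict overrides the default p, as the hyper-parameter pattern in this commit suggests; GAP and GMP refer to the aggregators defined alongside GeM.

```python
# Hedged sketch: p = 1 reproduces average pooling, a large p approaches max pooling.
import torch

fea = {"pool5": torch.rand(1, 4, 3, 3)}              # hypothetical post-ReLU map
gap = GAP()(fea)["pool5_GAP"]
gmp = GMP()(fea)["pool5_GMP"]
gem_p1 = GeM(hps={"p": 1.0})(fea)["pool5_GeM"]
gem_p64 = GeM(hps={"p": 64.0})(fea)["pool5_GeM"]

print(torch.allclose(gem_p1, gap))                   # True: p = 1 is average pooling
print((gem_p64 - gmp).abs().max())                   # small: large p nears max pooling
```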

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class GMP(AggregatorBase):
"""
Global maximum pooling
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
self.first_show = True
super(GMP, self).__init__(hps)
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
fea = (fea.max(dim=3)[0]).max(dim=2)[0]
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[GMP Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from ....index.utils import feature_loader
from typing import Dict
@AGGREGATORS.register
class PWA(AggregatorBase):
"""
Part-based Weighting Aggregation.
c.f. https://arxiv.org/abs/1705.01247
Hyper-Params
train_fea_dir (str): path of the feature directory for selecting channels.
n_proposal (int): number of proposals to be selected.
alpha (float): alpha for calculating the spatial weight.
beta (float): beta for calculating the spatial weight.
train_fea_names (list): names of the training features used for channel selection.
"""
default_hyper_params = {
"train_fea_dir": "",
"n_proposal": 25,
"alpha": 2.0,
"beta": 2.0,
"train_fea_names": ["pool5_GAP"],
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(PWA, self).__init__(hps)
self.first_show = True
assert self._hyper_params["train_fea_dir"] != ""
self.selected_proposals_idx = None
self.train()
def train(self) -> None:
n_proposal = self._hyper_params["n_proposal"]
stacked_fea, _, pos_info = feature_loader.load(
self._hyper_params["train_fea_dir"],
self._hyper_params["train_fea_names"]
)
self.selected_proposals_idx = dict()
for fea_name in pos_info:
st_idx, ed_idx = pos_info[fea_name]
fea = stacked_fea[:, st_idx: ed_idx]
assert fea.ndim == 2, "invalid train feature"
channel_variance = np.std(fea, axis=0)
selected_idx = channel_variance.argsort()[-n_proposal:]
fea_name = "_".join(fea_name.split("_")[:-1])
self.selected_proposals_idx[fea_name] = selected_idx.tolist()
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
alpha, beta = self._hyper_params["alpha"], self._hyper_params["beta"]
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
assert (key in self.selected_proposals_idx), '{} is not in the {}'.format(key, self.selected_proposals_idx.keys())
proposals_idx = np.array(self.selected_proposals_idx[key])
proposals = fea[:, proposals_idx, :, :]
power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdims=True) ** (1.0 / alpha)
normed_proposals = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)
fea = (fea[:, None, :, :, :] * normed_proposals[:, :, None, :, :]).sum(dim=(3, 4))
fea = fea.view(fea.shape[0], -1)
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[PWA Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret

View File

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict, List
@AGGREGATORS.register
class RMAC(AggregatorBase):
"""
Regional Maximum activation of convolutions (R-MAC).
c.f. https://arxiv.org/pdf/1511.05879.pdf
Hyper-Params
level_n (int): number of levels for selecting regions.
"""
default_hyper_params = {
"level_n": 3,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(RMAC, self).__init__(hps)
self.first_show = True
self.cached_regions = dict()
def _get_regions(self, h: int, w: int) -> List:
"""
Divide the image into several regions.
Args:
h (int): height for dividing regions.
w (int): width for dividing regions.
Returns:
regions (List): a list of region positions.
"""
if (h, w) in self.cached_regions:
return self.cached_regions[(h, w)]
m = 1
n_h, n_w = 1, 1
regions = list()
if h != w:
min_edge = min(h, w)
left_space = max(h, w) - min(h, w)
iou_target = 0.4
iou_best = 1.0
while True:
iou_tmp = (min_edge ** 2 - min_edge * (left_space // m)) / (min_edge ** 2)
# a small m may result in non-overlapping regions
if iou_tmp <= 0:
m += 1
continue
if abs(iou_tmp - iou_target) <= iou_best:
iou_best = abs(iou_tmp - iou_target)
m += 1
else:
break
if h < w:
n_w = m
else:
n_h = m
for i in range(self._hyper_params["level_n"]):
region_width = int(2 * 1.0 / (i + 2) * min(h, w))
step_size_h = (h - region_width) // n_h
step_size_w = (w - region_width) // n_w
for x in range(n_h):
for y in range(n_w):
st_x = step_size_h * x
ed_x = st_x + region_width - 1
assert ed_x < h
st_y = step_size_w * y
ed_y = st_y + region_width - 1
assert ed_y < w
regions.append((st_x, st_y, ed_x, ed_y))
n_h += 1
n_w += 1
self.cached_regions[(h, w)] = regions
return regions
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
h, w = fea.shape[2:]
final_fea = None
regions = self._get_regions(h, w)
for _, r in enumerate(regions):
st_x, st_y, ed_x, ed_y = r
region_fea = (fea[:, :, st_x: ed_x, st_y: ed_y].max(dim=3)[0]).max(dim=2)[0]
region_fea = region_fea / torch.norm(region_fea, dim=1, keepdim=True)
if final_fea is None:
final_fea = region_fea
else:
final_fea = final_fea + region_fea
ret[key + "_{}".format(self.__class__.__name__)] = final_fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[RMAC Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret
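For intuition, the region grid built by _get_regions can be inspected directly; the 7x7 size below is a hypothetical map (e.g. a 224x224 input through ResNet-50), not a value taken from this diff.

```python
# Hedged sketch: inspect the R-MAC regions for a hypothetical 7x7 square map.
rmac = RMAC()                        # default level_n = 3
regions = rmac._get_regions(7, 7)    # list of (st_x, st_y, ed_x, ed_y) tuples
print(len(regions))                  # 14 regions: 1 + 4 + 9 over the three levels
print(regions[0])                    # (0, 0, 6, 6), the whole map at level 1
```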

View File

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
import queue
import torch
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class SCDA(AggregatorBase):
"""
Selective Convolutional Descriptor Aggregation for Fine-Grained Image Retrieval.
c.f. http://www.weixiushen.com/publication/tip17SCDA.pdf
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(SCDA, self).__init__(hps)
self.first_show = True
def bfs(self, x: int, y: int, mask: torch.tensor, cc_map: torch.tensor, cc_id: int) -> int:
dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]]
q = queue.LifoQueue()
q.put((x, y))
ret = 1
cc_map[x][y] = cc_id
while not q.empty():
x, y = q.get()
for (dx, dy) in dirs:
new_x = x + dx
new_y = y + dy
if 0 <= new_x < mask.shape[0] and 0 <= new_y < mask.shape[1]:
if mask[new_x][new_y] == 1 and cc_map[new_x][new_y] == 0:
q.put((new_x, new_y))
ret += 1
cc_map[new_x][new_y] = cc_id
return ret
def find_max_cc(self, mask: torch.tensor) -> torch.tensor:
"""
Find the largest connected component of the mask
Args:
mask (torch.tensor): the original mask.
Returns:
mask (torch.tensor): the mask only containing the maximum connected component.
"""
assert mask.ndim == 4
assert mask.shape[1] == 1
mask = mask[:, 0, :, :]
for i in range(mask.shape[0]):
m = mask[i]
cc_map = torch.zeros(m.shape)
cc_num = list()
for x in range(m.shape[0]):
for y in range(m.shape[1]):
if m[x][y] == 1 and cc_map[x][y] == 0:
cc_id = len(cc_num) + 1
cc_num.append(self.bfs(x, y, m, cc_map, cc_id))
max_cc_id = cc_num.index(max(cc_num)) + 1
m[cc_map != max_cc_id] = 0
mask = mask[:, None, :, :]
return mask
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
mask = fea.sum(dim=1, keepdims=True)
thres = mask.mean(dim=(2, 3), keepdims=True)
mask[mask <= thres] = 0
mask[mask > thres] = 1
mask = self.find_max_cc(mask)
fea = fea * mask
gap = fea.mean(dim=(2, 3))
gmp, _ = fea.max(dim=3)
gmp, _ = gmp.max(dim=2)
ret[key + "_{}".format(self.__class__.__name__)] = torch.cat([gap, gmp], dim=1)
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[SCDA Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class SPoC(AggregatorBase):
"""
SPoC with center prior.
c.f. https://arxiv.org/pdf/1510.07493.pdf
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(SPoC, self).__init__(hps)
self.first_show = True
self.spatial_weight_cache = dict()
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
ret = dict()
for key in features:
fea = features[key]
if fea.ndimension() == 4:
h, w = fea.shape[2:]
if (h, w) in self.spatial_weight_cache:
spatial_weight = self.spatial_weight_cache[(h, w)]
else:
sigma = min(h, w) / 2.0 / 3.0
x = torch.Tensor(range(w))
y = torch.Tensor(range(h))[:, None]
spatial_weight = torch.exp(-((x - (w - 1) / 2.0) ** 2 + (y - (h - 1) / 2.0) ** 2) / 2.0 / (sigma ** 2))
if torch.cuda.is_available():
spatial_weight = spatial_weight.cuda()
spatial_weight = spatial_weight[None, None, :, :]
self.spatial_weight_cache[(h, w)] = spatial_weight
fea = (fea * spatial_weight).sum(dim=(2, 3))
ret[key + "_{}".format(self.__class__.__name__)] = fea
else:
# In case of fc feature.
assert fea.ndimension() == 2
if self.first_show:
print("[SPoC Aggregator]: find 2-dimension feature map, skip aggregation")
self.first_show = False
ret[key] = fea
return ret
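A minimal usage sketch for SPoC; sizes and the feature name are assumptions. The input is moved to GPU when one is available because the class builds its spatial weight on GPU in that case.

```python
# Hedged usage sketch: SPoC with the Gaussian center prior on a dummy map.
import torch

fea = torch.randn(2, 512, 7, 7)
if torch.cuda.is_available():        # keep the input on the same device as the weight
    fea = fea.cuda()
out = SPoC()({"pool5": fea})
print(out["pool5_SPoC"].shape)       # torch.Size([2, 512])
```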

View File

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import AGGREGATORS, SPLITTERS, EXTRACTORS
from .extractor import ExtractorBase
from .splitter import SplitterBase
from .aggregator import AggregatorBase
from .helper import ExtractHelper
from ..utils import simple_build
import torch.nn as nn
from typing import List
def build_aggregators(cfg: CfgNode) -> List[AggregatorBase]:
"""
Instantiate a list of aggregator classes.
Args:
cfg (CfgNode): the configuration tree.
Returns:
aggregators (list): a list of instances of aggregator class.
"""
names = cfg["names"]
aggregators = list()
for name in names:
aggregators.append(simple_build(name, cfg, AGGREGATORS))
return aggregators
def build_extractor(model: nn.Module, cfg: CfgNode) -> ExtractorBase:
"""
Instantiate an extractor class.
Args:
model (nn.Module): the model for extracting features.
cfg (CfgNode): the configuration tree.
Returns:
extractor (ExtractorBase): an instance of extractor class.
"""
name = cfg["name"]
extractor = simple_build(name, cfg, EXTRACTORS, model=model)
return extractor
def build_splitter(cfg: CfgNode) -> SplitterBase:
"""
Instantiate a splitter class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
splitter (SplitterBase): an instance of splitter class.
"""
name = cfg["name"]
splitter = simple_build(name, cfg, SPLITTERS)
return splitter
def build_extract_helper(model: nn.Module, cfg: CfgNode) -> ExtractHelper:
"""
Instantiate an extract helper class.
Args:
model (nn.Module): the model for extracting features.
cfg (CfgNode): the configuration tree.
Returns:
helper (ExtractHelper): an instance of extract helper class.
"""
assemble = cfg.assemble
extractor = build_extractor(model, cfg.extractor)
splitter = build_splitter(cfg.splitter)
aggregators = build_aggregators(cfg.aggregators)
helper = ExtractHelper(assemble, extractor, splitter, aggregators)
return helper

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import EXTRACTORS, SPLITTERS, AGGREGATORS
from ..utils import get_config_from_registry
def get_aggregators_cfg() -> CfgNode:
cfg = get_config_from_registry(AGGREGATORS)
cfg["names"] = list()
return cfg
def get_splitter_cfg() -> CfgNode:
cfg = get_config_from_registry(SPLITTERS)
cfg["name"] = "unknown"
return cfg
def get_extractor_cfg() -> CfgNode:
cfg = get_config_from_registry(EXTRACTORS)
cfg["name"] = "unknown"
return cfg
def get_extract_cfg() -> CfgNode:
cfg = CfgNode()
cfg["assemble"] = 0
cfg["extractor"] = get_extractor_cfg()
cfg["splitter"] = get_splitter_cfg()
cfg["aggregators"] = get_aggregators_cfg()
return cfg
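A hedged sketch of wiring this config into build_extract_helper. The per-module sub-keys (e.g. cfg.extractor.ResSeries) depend on get_config_from_registry, which is not shown in this diff, so treat that key layout and the torchvision backbone as assumptions.

```python
# Hedged sketch: build an ExtractHelper from the default extract config.
from torchvision.models import resnet50

cfg = get_extract_cfg()
cfg.assemble = 0
cfg.extractor.name = "ResSeries"
cfg.extractor.ResSeries.extract_features = ["pool5"]   # assumed key layout
cfg.splitter.name = "Identity"
cfg.aggregators.names = ["GAP"]

helper = build_extract_helper(resnet50(pretrained=True), cfg)
```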

View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
import torch.nn as nn
from .extractors_impl.vgg_series import VggSeries
from .extractors_impl.res_series import ResSeries
from .extractors_impl.reid_series import ReIDSeries
from .extractors_base import ExtractorBase
__all__ = [
'ExtractorBase',
'VggSeries', 'ResSeries',
'ReIDSeries',
]

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
from functools import partial
import torch
import torch.nn as nn
import numpy as np
from ...utils import ModuleBase
from typing import Dict
class ExtractorBase(ModuleBase):
"""
The base class for feature map extractors.
Hyper-Parameters
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
If it is ["all"], then all available features will be output.
"""
available_feas = list()
default_hyper_params = {
"extract_features": list(),
}
def __init__(self, model: nn.Module, feature_modules: Dict[str, nn.Module], hps: Dict or None = None):
"""
Args:
model (nn.Module): the model for extracting features.
feature_modules (dict): the output layer of the model.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(ExtractorBase, self).__init__(hps)
assert len(self._hyper_params["extract_features"]) > 0
self.model = model.eval()
if torch.cuda.is_available():
self.model.cuda()
if torch.cuda.device_count() > 1:
self.model = nn.DataParallel(self.model)
self.feature_modules = feature_modules
self.feature_buffer = dict()
if self._hyper_params["extract_features"][0] == "all":
self._hyper_params["extract_features"] = self.available_feas
for fea in self._hyper_params["extract_features"]:
self.feature_buffer[fea] = dict()
self._register_hook()
def _register_hook(self) -> None:
"""
Register hooks to output inner feature map.
"""
def hook(feature_buffer, fea_name, module, input, output):
feature_buffer[fea_name][str(output.device)] = output.data
for fea in self._hyper_params["extract_features"]:
assert fea in self.feature_modules, 'unknown feature {}!'.format(fea)
self.feature_modules[fea].register_forward_hook(partial(hook, self.feature_buffer, fea))
def __call__(self, x: torch.tensor) -> Dict:
with torch.no_grad():
self.model(x)
ret = dict()
for fea in self._hyper_params["extract_features"]:
ret[fea] = list()
devices = list(self.feature_buffer[fea].keys())
devices = np.sort(devices)
for d in devices:
ret[fea].append(self.feature_buffer[fea][d])
ret[fea] = torch.cat(ret[fea], dim=0)
return ret

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
import torch.nn as nn
from ..extractors_base import ExtractorBase
from ...registry import EXTRACTORS
from typing import Dict
@EXTRACTORS.register
class ReIDSeries(ExtractorBase):
"""
The extractors for reid baseline models.
Hyper-Parameters
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
If it is ["all"], then all available features will be output.
"""
default_hyper_params = {
"extract_features": list(),
}
available_feas = ["output"]
def __init__(self, model: nn.Module, hps: Dict or None = None):
"""
Args:
model (nn.Module): the model for extracting features.
hps (dict): default hyper parameters in a dict (keys, values).
"""
children = list(model.children())
feature_modules = {
"output": children[1].add_block,
}
super(ReIDSeries, self).__init__(model, feature_modules, hps)

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
from ..extractors_base import ExtractorBase
from ...registry import EXTRACTORS
from typing import Dict
@EXTRACTORS.register
class ResSeries(ExtractorBase):
"""
The extractors for ResNet.
Hyper-Parameters
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
If it is ["all"], then all available features will be output.
"""
default_hyper_params = {
"extract_features": list(),
}
available_feas = ["pool5", "pool4", "pool3"]
def __init__(self, model, hps: Dict or None = None):
"""
Args:
model (nn.Module): the model for extracting features.
hps (dict): default hyper parameters in a dict (keys, values).
"""
children = list(model.children())
feature_modules = {
"pool5": children[-3][-1].relu,
"pool4": children[-4][-1].relu,
"pool3": children[-5][-1].relu
}
super(ResSeries, self).__init__(model, feature_modules, hps)
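A hedged usage sketch for the ResNet extractor above, assuming a plain torchvision ResNet-50 (whose children match the indexing used here) and assuming the hps dict overrides extract_features.

```python
# Hedged sketch: hook the pool5 activations out of a torchvision ResNet-50.
import torch
from torchvision.models import resnet50

extractor = ResSeries(resnet50(pretrained=True), hps={"extract_features": ["pool5"]})
img = torch.randn(2, 3, 224, 224)
if torch.cuda.is_available():        # ExtractorBase moves the model to GPU when available
    img = img.cuda()
feas = extractor(img)
print(feas["pool5"].shape)           # torch.Size([2, 2048, 7, 7])
```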

View File

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
from ..extractors_base import ExtractorBase
from ...registry import EXTRACTORS
from typing import Dict
@EXTRACTORS.register
class VggSeries(ExtractorBase):
"""
The extractors for VGG Net.
Hyper-Parameters
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
If it is ["all"], then all available features will be output.
"""
default_hyper_params = {
"extract_features": list(),
}
available_feas = ["fc", "pool5", "pool4"]
def __init__(self, model, hps: Dict or None = None):
"""
Args:
model (nn.Module): the model for extracting features.
hps (dict): default hyper parameters in a dict (keys, values).
"""
feature_modules = {
"fc": model.classifier[4],
"pool5": model.features[-1],
"pool4": model.features[23]
}
super(VggSeries, self).__init__(model, feature_modules, hps)

View File

@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
from .helper import ExtractHelper
__all__ = [
'ExtractHelper',
]

View File

@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
import os
import pickle
from copy import deepcopy
from tqdm import tqdm
import torch
from ..extractor import ExtractorBase
from ..aggregator import AggregatorBase
from ..splitter import SplitterBase
from ...utils import ensure_dir
from torch.utils.data import DataLoader
from typing import Dict, List
class ExtractHelper:
"""
A helper class to extract feature maps from model, and then aggregate them.
"""
def __init__(self, assemble: int, extractor: ExtractorBase, splitter: SplitterBase, aggregators: List[AggregatorBase]):
"""
Args:
assemble (int): way to assemble features if transformers produce multiple images (e.g. TwoFlip, TenCrop).
extractor (ExtractorBase): an extractor class for extracting features.
splitter (SplitterBase): a splitter class for splitting features.
aggregators (list): a list of aggregator classes for aggregating features.
"""
self.assemble = assemble
self.extractor = extractor
self.splitter = splitter
self.aggregators = aggregators
def _save_part_fea(self, datainfo: Dict, save_fea: List, save_path: str) -> None:
"""
Save features in a json file.
Args:
datainfo (dict): the dataset information contained in the data json file.
save_fea (list): a list of features to be saved.
save_path (str): the save path for the extracted features.
"""
save_json = dict()
for key in datainfo:
if key != "info_dicts":
save_json[key] = datainfo[key]
save_json["info_dicts"] = save_fea
with open(save_path, "wb") as f:
pickle.dump(save_json, f)
def extract_one_batch(self, batch: Dict) -> Dict:
"""
Extract features for a batch of images.
Args:
batch (dict): a dict containing several image tensors.
Returns:
all_fea_dict (dict): a dict containing extracted features.
"""
img = batch["img"]
if torch.cuda.is_available():
img = img.cuda()
# img is in the shape (N, IMG_AUG, C, H, W)
batch_size, aug_size = img.shape[0], img.shape[1]
img = img.view(-1, img.shape[2], img.shape[3], img.shape[4])
features = self.extractor(img)
features = self.splitter(features)
all_fea_dict = dict()
for aggregator in self.aggregators:
fea_dict = aggregator(features)
all_fea_dict.update(fea_dict)
# PyTorch will duplicate inputs if batch_size < n_gpu
for key in all_fea_dict.keys():
if self.assemble == 0:
features = all_fea_dict[key][:img.shape[0], :]
features = features.view(batch_size, aug_size, -1)
features = features.view(batch_size, -1)
all_fea_dict[key] = features
elif self.assemble == 1:
features = all_fea_dict[key].view(batch_size, aug_size, -1)
features = features.sum(dim=1)
all_fea_dict[key] = features
return all_fea_dict
def do_extract(self, dataloader: DataLoader, save_path: str, save_interval: int = 5000) -> None:
"""
Extract features for a whole dataset and save features in json files.
Args:
dataloader (DataLoader): a DataLoader for loading the dataset images.
save_path (str): the save path for the extracted features.
save_interval (int, optional): number of features saved in one part file.
"""
datainfo = dataloader.dataset.data_info
pbar = tqdm(range(len(dataloader)))
save_fea = list()
part_cnt = 0
ensure_dir(save_path)
for _, batch in zip(pbar, dataloader):
feature_dict = self.extract_one_batch(batch)
for i in range(len(batch["img"])):
idx = batch["idx"][i]
save_fea.append(deepcopy(datainfo["info_dicts"][idx]))
single_fea_dict = dict()
for key in feature_dict:
single_fea_dict[key] = feature_dict[key][i].tolist()
save_fea[-1]["feature"] = single_fea_dict
save_fea[-1]["idx"] = int(idx)
if len(save_fea) >= save_interval:
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
part_cnt += 1
del save_fea
save_fea = list()
if len(save_fea) >= 1:
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
def do_single_extract(self, img: torch.Tensor) -> List[Dict]:
"""
Extract features for a single image.
Args:
img (torch.Tensor): a single image tensor.
Returns:
[fea_dict] (list): the extracted features of the image, wrapped in a single-element list.
"""
batch = dict()
batch["img"] = img.view(1, img.shape[0], img.shape[1], img.shape[2], img.shape[3])
fea_dict = self.extract_one_batch(batch)
return [fea_dict]
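A hedged sketch of single-image extraction with the helper above; helper is assumed to be an ExtractHelper built via build_extract_helper, and the tensor stands in for the output of a single-crop transform pipeline (n_aug = 1).

```python
# Hedged sketch: do_single_extract expects an (n_aug, C, H, W) tensor.
import torch

img = torch.randn(1, 3, 224, 224)          # hypothetical single-crop input
fea_dict = helper.do_single_extract(img)[0]
print(list(fea_dict.keys()))               # e.g. ['pool5_GAP'] for a GAP pipeline
```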

View File

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from ..utils.registry import Registry
EXTRACTORS = Registry()
SPLITTERS = Registry()
AGGREGATORS = Registry()

Binary file not shown.

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .splitter_impl.identity import Identity
from .splitter_impl.pcb import PCB
from .splitter_base import SplitterBase
__all__ = [
'SplitterBase',
'Identity', 'PCB',
]

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
import torch
from ...utils import ModuleBase
from typing import Dict
class SplitterBase(ModuleBase):
"""
The base class for splitter functions.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(SplitterBase, self).__init__(hps)
@abstractmethod
def __call__(self, features: torch.tensor) -> Dict:
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..splitter_base import SplitterBase
from ...registry import SPLITTERS
from typing import Dict
@SPLITTERS.register
class Identity(SplitterBase):
"""
Directly return feature maps without any operations.
"""
default_hyper_params = dict()
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(Identity, self).__init__(hps)
def __call__(self, features: torch.tensor) -> torch.tensor:
return features

View File

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..splitter_base import SplitterBase
from ...registry import SPLITTERS
from typing import Dict
@SPLITTERS.register
class PCB(SplitterBase):
"""
PCB function to split feature maps.
c.f. http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf
Hyper-Params:
stripe_num (int): the number of stripes divided.
"""
default_hyper_params = {
'stripe_num': 2,
}
def __init__(self, hps: Dict or None = None):
"""
Args:
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(PCB, self).__init__(hps)
def __call__(self, features: torch.tensor) -> Dict:
ret = dict()
for key in features:
fea = features[key]
assert fea.ndimension() == 4
assert self._hyper_params["stripe_num"] <= fea.shape[2], \
'stripe num must be less than or equal to the height of fea'
stride = fea.shape[2] // self._hyper_params["stripe_num"]
for i in range(int(self._hyper_params["stripe_num"])):
ret[key + "_part_{}".format(i)] = fea[:, :, stride * i: stride * (i + 1), :]
ret[key + "_global"] = fea
return ret
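A minimal sketch of the splitter above on a dummy map; the feature name and sizes are illustrative only.

```python
# Hedged sketch: PCB yields one tensor per stripe plus the global map.
import torch

pcb = PCB(hps={"stripe_num": 2})
out = pcb({"pool5": torch.randn(2, 2048, 8, 4)})
print(sorted(out.keys()))            # ['pool5_global', 'pool5_part_0', 'pool5_part_1']
print(out["pool5_part_0"].shape)     # torch.Size([2, 2048, 4, 4])
```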

View File

@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
from .make_data_json import make_data_json
from .split_dataset import split_dataset
__all__ = [
'split_dataset', 'make_data_json',
]

View File

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
import pickle
import os
def make_ds_for_general(dataset_path: str, save_path: str) -> None:
"""
Generate the data json file for a dataset that collects images with the same label in one directory, e.g. CUB-200-2011.
Args:
dataset_path (str): the path of the dataset.
save_path (str): the path for saving the data json files.
"""
info_dicts = list()
img_dirs = os.listdir(dataset_path)
label_list = list()
label_to_idx = dict()
for dir in img_dirs:
for root, _, files in os.walk(os.path.join(dataset_path, dir)):
for file in files:
info_dict = dict()
info_dict['path'] = os.path.join(root, file)
if dir not in label_list:
label_to_idx[dir] = len(label_list)
label_list.append(dir)
info_dict['label'] = dir
info_dict['label_idx'] = label_to_idx[dir]
info_dicts += [info_dict]
with open(save_path, 'wb') as f:
pickle.dump({'nr_class': len(img_dirs), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
def make_ds_for_oxford(dataset_path: str, save_path: str or None = None, gt_path: str or None = None) -> None:
"""
Generate the data json file for the Oxford dataset.
Args:
dataset_path (str): the path of the dataset.
save_path (str): the path for saving the data json files.
gt_path (str, optional): the path of the ground truth, necessary for Oxford.
"""
label_list = list()
info_dicts = list()
query_info = dict()
if 'query' in dataset_path:
for root, _, files in os.walk(gt_path):
for file in files:
if 'query' in file:
with open(os.path.join(root, file), 'r') as f:
line = f.readlines()[0].strip('\n').split(' ')
query_name = file[:-10]
label = line[0][5:]
bbox = [float(line[1]), float(line[2]), float(line[3]), float(line[4])]
query_info[label] = {'query_name': query_name, 'bbox': bbox,}
for root, _, files in os.walk(dataset_path):
for file in files:
info_dict = dict()
info_dict['path'] = os.path.join(root, file)
label = file.split('.')[0]
if label not in label_list:
label_list.append(label)
info_dict['label'] = label
if 'query' in dataset_path:
info_dict['bbox'] = query_info[label]['bbox']
info_dict['query_name'] = query_info[label]['query_name']
info_dicts += [info_dict]
with open(save_path, 'wb') as f:
pickle.dump({'nr_class': len(label_list), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
def make_ds_for_reid(dataset_path: str, save_path: str) -> None:
"""
Generate the data json file for a Re-ID dataset.
Args:
dataset_path (str): the path of the dataset.
save_path (str): the path for saving the data json files.
"""
label_list = list()
info_dicts = list()
for root, _, files in os.walk(dataset_path):
for file in files:
info_dict = dict()
info_dict['path'] = os.path.join(root, file)
label = file.split('_')[0]
cam = file.split('_')[1][1]
if label not in label_list:
label_list.append(label)
info_dict['label'] = label
info_dict['cam'] = cam
info_dicts += [info_dict]
with open(save_path, 'wb') as f:
pickle.dump({'nr_class': len(label_list), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
def make_data_json(dataset_path: str, save_path: str, type: str, gt_path: str or None=None) -> None:
"""
Generate data json file for dataset.
Args:
dataset_path (str): the path of the dataset.
save_path (str): the path for saving the data json files.
type (str): the structure type of the dataset.
gt_path (str, optional): the path of the ground truth, necessary for Oxford.
"""
assert type in ['general', 'oxford', 'reid']
if type == 'general':
make_ds_for_general(dataset_path, save_path)
elif type == 'oxford':
make_ds_for_oxford(dataset_path, save_path, gt_path)
elif type == 'reid':
make_ds_for_reid(dataset_path, save_path)
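A hedged usage sketch for make_data_json; the paths below are hypothetical placeholders.

```python
# Hedged sketch: package a gallery folder of a general recognition dataset.
make_data_json(
    dataset_path="/data/caltech101/gallery",      # hypothetical path
    save_path="data_jsons/caltech_gallery.json",  # hypothetical path
    type="general",
)
```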

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
import os
from shutil import copyfile
def split_dataset(dataset_path: str, split_file: str) -> None:
"""
Split the dataset according to the given splitting rules.
Args:
dataset_path (str): the path of the dataset.
split_file (str): the path of the file containing the splitting rules.
"""
with open(split_file, 'r') as f:
lines = f.readlines()
for line in lines:
path = line.strip('\n').split(' ')[0]
is_gallery = line.strip('\n').split(' ')[1]
if is_gallery == '0':
src = os.path.join(dataset_path, path)
dst = src.replace(path.split('/')[0], 'query')
dst_index = len(dst.split('/')[-1])
dst_dir = dst[:len(dst) - dst_index]
if not os.path.isdir(dst_dir):
os.makedirs(dst_dir)
os.symlink(src, dst)
elif is_gallery == '1':
src = os.path.join(dataset_path, path)
dst = src.replace(path.split('/')[0], 'gallery')
dst_index = len(dst.split('/')[-1])
dst_dir = dst[:len(dst) - dst_index]
if not os.path.isdir(dst_dir):
os.makedirs(dst_dir)
os.symlink(src, dst)
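A hedged sketch of the expected split file and a call to split_dataset; the paths are hypothetical, and the flag semantics (0 = query, 1 = gallery) follow the branches above.

```python
# Hedged sketch: each split-file line is "<relative/path> <flag>", e.g.
#
#   images/001.ak47/001_0001.jpg 1
#   images/001.ak47/001_0002.jpg 0
#
# flag 0 symlinks the image under query/, flag 1 under gallery/.
split_dataset("/data/caltech101", "split_file/caltech_split.txt")
```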

Binary file not shown.

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from .config import get_index_cfg
from .builder import build_index_helper
from .utils import feature_loader
__all__ = [
'get_index_cfg',
'build_index_helper',
'feature_loader',
]

View File

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import ENHANCERS, METRICS, DIMPROCESSORS, RERANKERS
from .feature_enhancer import EnhanceBase
from .helper import IndexHelper
from .metric import MetricBase
from .dim_processor import DimProcessorBase
from .re_ranker import ReRankerBase
from ..utils import simple_build
from typing import List
def build_enhance(cfg: CfgNode) -> EnhanceBase:
"""
Instantiate a feature enhancer class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
enhance (EnhanceBase): an instance of feature enhancer class.
"""
name = cfg["name"]
enhance = simple_build(name, cfg, ENHANCERS)
return enhance
def build_metric(cfg: CfgNode) -> MetricBase:
"""
Instantiate a metric class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
metric (MetricBase): an instance of metric class.
"""
name = cfg["name"]
metric = simple_build(name, cfg, METRICS)
return metric
def build_processors(feature_names: List[str], cfg: CfgNode) -> List[DimProcessorBase]:
"""
Instantiate a list of dimension processor classes.
Args:
feature_names (list): a list of feature names to be processed.
cfg (CfgNode): the configuration tree.
Returns:
processors (list): a list of instances of dimension processor class.
"""
names = cfg["names"]
processors = list()
for name in names:
processors.append(simple_build(name, cfg, DIMPROCESSORS, feature_names=feature_names))
return processors
def build_ranker(cfg: CfgNode) -> ReRankerBase:
"""
Instantiate a re-ranker class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
re_rank (ReRankerBase): an instance of re-ranker class.
"""
name = cfg["name"]
re_rank = simple_build(name, cfg, RERANKERS)
return re_rank
def build_index_helper(cfg: CfgNode) -> IndexHelper:
"""
Instantiate an index helper class.
Args:
cfg (CfgNode): the configuration tree.
Returns:
helper (IndexHelper): an instance of index helper class.
"""
dim_processors = build_processors(cfg["feature_names"], cfg.dim_processors)
metric = build_metric(cfg.metric)
feature_enhancer = build_enhance(cfg.feature_enhancer)
re_ranker = build_ranker(cfg.re_ranker)
helper = IndexHelper(dim_processors, feature_enhancer, metric, re_ranker)
return helper

View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import ENHANCERS, METRICS, DIMPROCESSORS, RERANKERS
from ..utils import get_config_from_registry
def get_enhancer_cfg() -> CfgNode:
cfg = get_config_from_registry(ENHANCERS)
cfg["name"] = "unknown"
return cfg
def get_metric_cfg() -> CfgNode:
cfg = get_config_from_registry(METRICS)
cfg["name"] = "unknown"
return cfg
def get_processors_cfg() -> CfgNode:
cfg = get_config_from_registry(DIMPROCESSORS)
cfg["names"] = ["unknown"]
return cfg
def get_ranker_cfg() -> CfgNode:
cfg = get_config_from_registry(RERANKERS)
cfg["name"] = "unknown"
return cfg
def get_index_cfg() -> CfgNode:
cfg = CfgNode()
cfg["query_fea_dir"] = "unknown"
cfg["gallery_fea_dir"] = "unknown"
cfg["feature_names"] = ["all"]
cfg["dim_processors"] = get_processors_cfg()
cfg["feature_enhancer"] = get_enhancer_cfg()
cfg["metric"] = get_metric_cfg()
cfg["re_ranker"] = get_ranker_cfg()
return cfg
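A hedged sketch of filling the index config above. Only the dim-processor names are defined within this diff; the metric, feature-enhancer and re-ranker registries live in files not shown here, so their names must be set from those modules before building the helper.

```python
# Hedged sketch: partially filled index config (paths are hypothetical).
cfg = get_index_cfg()
cfg.query_fea_dir = "features/caltech_query"
cfg.gallery_fea_dir = "features/caltech_gallery"
cfg.feature_names = ["pool5_GAP"]
cfg.dim_processors.names = ["L2Normalize"]
# cfg.metric.name / cfg.feature_enhancer.name / cfg.re_ranker.name must be set
# from the registries defined elsewhere in this commit, then:
# helper = build_index_helper(cfg)
```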

Binary file not shown.

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .dim_processors_impl.identity import Identity
from .dim_processors_impl.l2_normalize import L2Normalize
from .dim_processors_impl.part_pca import PartPCA
from .dim_processors_impl.part_svd import PartSVD
from .dim_processors_impl.pca import PCA
from .dim_processors_impl.svd import SVD
from .dim_processors_impl.rmac_pca import RMACPCA
from .dim_processors_base import DimProcessorBase
__all__ = [
'DimProcessorBase',
'Identity', 'L2Normalize', 'PartPCA', 'PartSVD', 'PCA', 'SVD', 'RMACPCA',
]

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
import numpy as np
from ...utils import ModuleBase
from typing import Dict, List
class DimProcessorBase(ModuleBase):
"""
The base class of dimension processor.
"""
default_hyper_params = dict()
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
ModuleBase.__init__(self, hps)
self.feature_names = feature_names
@abstractmethod
def __call__(self, fea: np.ndarray) -> np.ndarray:
pass

View File

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
import glob
from os.path import basename, dirname, isfile
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
]

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from typing import Dict, List
@DIMPROCESSORS.register
class Identity(DimProcessorBase):
"""
Directly return feature without any dimension process operations.
"""
default_hyper_params = dict()
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(Identity, self).__init__(feature_names, hps)
def __call__(self, fea: np.ndarray) -> np.ndarray:
return fea

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from sklearn.preprocessing import normalize
from typing import Dict, List
@DIMPROCESSORS.register
class L2Normalize(DimProcessorBase):
"""
L2 normalize the features.
"""
default_hyper_params = dict()
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(L2Normalize, self).__init__(feature_names, hps)
def __call__(self, fea: np.ndarray) -> np.ndarray:
return normalize(fea, norm="l2")

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from ...utils import feature_loader
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA as SKPCA
from typing import Dict, List
@DIMPROCESSORS.register
class PartPCA(DimProcessorBase):
"""
Part PCA divides the whole feature into several parts and then applies a PCA transformation to each part.
It is usually used for features that are extracted from several feature maps and concatenated together.
Hyper-Params:
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
whiten (bool): whether do whiten for each part.
train_fea_dir (str): the path of features for training PCA.
l2 (bool): whether do l2-normalization for the training features.
"""
default_hyper_params = {
"proj_dim": 0,
"whiten": True,
"train_fea_dir": "unknown",
"l2": True,
}
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(PartPCA, self).__init__(feature_names, hps)
self.pcas = dict()
self._train(self._hyper_params["train_fea_dir"])
def _train(self, fea_dir: str) -> None:
"""
Train the part PCA.
Args:
fea_dir (str): the path of features for training part PCA.
"""
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
fea_names = list(pos_info.keys())
ori_dim = fea.shape[1]
already_proj_dim = 0
for fea_name in fea_names:
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
ori_part_dim = ed_idx - st_idx
if self._hyper_params["proj_dim"] == 0:
proj_part_dim = ori_part_dim
else:
ratio = self._hyper_params["proj_dim"] * 1.0 / ori_dim
if fea_name != fea_names[-1]:
proj_part_dim = int(ori_part_dim * ratio)
else:
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
assert proj_part_dim <= ori_part_dim, "reduction dimension can not be distributed to each part!"
already_proj_dim += proj_part_dim
pca = SKPCA(n_components=proj_part_dim, whiten=self._hyper_params["whiten"])
train_fea = fea[:, st_idx: ed_idx]
if self._hyper_params["l2"]:
train_fea = normalize(train_fea, norm="l2")
pca.fit(train_fea)
self.pcas[fea_name] = {
"pos": (st_idx, ed_idx),
"pca": pca
}
def __call__(self, fea: np.ndarray) -> np.ndarray:
fea_names = np.sort(list(self.pcas.keys()))
ret = list()
for fea_name in fea_names:
st_idx, ed_idx = self.pcas[fea_name]["pos"][0], self.pcas[fea_name]["pos"][1]
pca = self.pcas[fea_name]["pca"]
ori_fea = fea[:, st_idx: ed_idx]
proj_fea = pca.transform(ori_fea)
ret.append(proj_fea)
ret = np.concatenate(ret, axis=1)
return ret

View File

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from ...utils import feature_loader
from sklearn.preprocessing import normalize
from sklearn.decomposition import TruncatedSVD as SKSVD
from typing import Dict, List
@DIMPROCESSORS.register
class PartSVD(DimProcessorBase):
"""
Part SVD divides the whole feature into several parts and then applies an SVD transformation to each part.
It is usually used for features that are extracted from several feature maps and concatenated together.
Hyper-Params:
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done
(in SVD, we will minus origin dimension by 1).
whiten (bool): whether do whiten for each part.
train_fea_dir (str): the path of features for training SVD.
l2 (bool): whether do l2-normalization for the training features.
"""
default_hyper_params = {
"proj_dim": 0,
"whiten": True,
"train_fea_dir": "unknown",
"l2": True,
}
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(PartSVD, self).__init__(feature_names, hps)
self.svds = dict()
self._train(self._hyper_params["train_fea_dir"])
def _train(self, fea_dir: str) -> None:
"""
Train the part SVD.
Args:
fea_dir (str): the path of features for training part SVD.
"""
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
fea_names = list(pos_info.keys())
ori_dim = fea.shape[1]
already_proj_dim = 0
for fea_name in fea_names:
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
ori_part_dim = ed_idx - st_idx
if self._hyper_params["proj_dim"] == 0:
proj_part_dim = ori_part_dim - 1
else:
ratio = self._hyper_params["proj_dim"] * 1.0 / ori_dim
if fea_name != fea_names[-1]:
proj_part_dim = int(ori_part_dim * ratio)
else:
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
assert proj_part_dim < ori_part_dim, "reduction dimension can not be distributed to each part!"
svd = SKSVD(n_components=proj_part_dim)
train_fea = fea[:, st_idx: ed_idx]
if self._hyper_params["l2"]:
train_fea = normalize(train_fea, norm="l2")
train_fea = svd.fit_transform(train_fea)
std = train_fea.std(axis=0, keepdims=True)
self.svds[fea_name] = {
"pos": (st_idx, ed_idx),
"svd": svd,
"std": std
}
def __call__(self, fea: np.ndarray) -> np.ndarray:
# Collect each part's reduced feature and concatenate them, since the reduced
# parts no longer align with the original (st_idx, ed_idx) positions.
ret = list()
for fea_name in self.svds:
st_idx, ed_idx = self.svds[fea_name]["pos"][0], self.svds[fea_name]["pos"][1]
svd = self.svds[fea_name]["svd"]
proj_fea = fea[:, st_idx: ed_idx]
proj_fea = svd.transform(proj_fea)
if self._hyper_params["whiten"]:
proj_fea = proj_fea / (self.svds[fea_name]["std"] + 1e-6)
ret.append(proj_fea)
ret = np.concatenate(ret, axis=1)
return ret

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from ...utils import feature_loader
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA as SKPCA
from typing import Dict, List
@DIMPROCESSORS.register
class PCA(DimProcessorBase):
"""
Do the PCA transformation for dimension reduction.
Hyper-Params:
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
whiten (bool): whether do whiten.
train_fea_dir (str): the path of features for training PCA.
l2 (bool): whether do l2-normalization for the training features.
"""
default_hyper_params = {
"proj_dim": 0,
"whiten": True,
"train_fea_dir": "unknown",
"l2": True,
}
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(PCA, self).__init__(feature_names, hps)
self.pca = SKPCA(n_components=self._hyper_params["proj_dim"], whiten=self._hyper_params["whiten"])
self._train(self._hyper_params["train_fea_dir"])
def _train(self, fea_dir: str) -> None:
"""
Train the PCA.
Args:
fea_dir (str): the path of features for training PCA.
"""
train_fea, _, _ = feature_loader.load(fea_dir, self.feature_names)
if self._hyper_params["l2"]:
train_fea = normalize(train_fea, norm="l2")
self.pca.fit(train_fea)
def __call__(self, fea: np.ndarray) -> np.ndarray:
ori_fea = fea
proj_fea = self.pca.transform(ori_fea)
return proj_fea

View File

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from ...utils import feature_loader
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from typing import Dict, List
@DIMPROCESSORS.register
class RMACPCA(DimProcessorBase):
"""
Do the PCA transformation for R-MAC only.
When this transformation is called, each part (region) feature is l2-normalized, projected by PCA and l2-normalized again.
Then the summed global feature is l2-normalized.
Hyper-Params:
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
whiten (bool): whether do whiten for each part.
train_fea_dir (str): the path of features for training PCA.
l2 (bool): whether do l2-normalization for the training features.
"""
default_hyper_params = {
"proj_dim": 0,
"whiten": True,
"train_fea_dir": "unknown",
"l2": True,
}
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(RMACPCA, self).__init__(feature_names, hps)
self.pca = dict()
self._train(self._hyper_params["train_fea_dir"])
def _train(self, fea_dir: str) -> None:
"""
Train the PCA for R-MAC.
Args:
fea_dir (str): the path of features for training PCA.
"""
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
fea_names = np.sort(list(pos_info.keys()))
region_feas = list()
for fea_name in fea_names:
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
region_fea = fea[:, st_idx: ed_idx]
region_feas.append(region_fea)
train_fea = np.concatenate(region_feas, axis=0)
if self._hyper_params["l2"]:
train_fea = normalize(train_fea, norm="l2")
pca = PCA(n_components=self._hyper_params["proj_dim"], whiten=self._hyper_params["whiten"])
pca.fit(train_fea)
self.pca = {
"pca": pca,
"pos_info": pos_info
}
def __call__(self, fea: np.ndarray) -> np.ndarray:
pca = self.pca["pca"]
pos_info = self.pca["pos_info"]
fea_names = np.sort(list(pos_info.keys()))
final_fea = None
for fea_name in fea_names:
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
region_fea = fea[:, st_idx: ed_idx]
region_fea = normalize(region_fea)
region_fea = pca.transform(region_fea)
region_fea = normalize(region_fea)
if final_fea is None:
final_fea = region_fea
else:
final_fea += region_fea
final_fea = normalize(final_fea)
return final_fea

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
import numpy as np
from ..dim_processors_base import DimProcessorBase
from ...registry import DIMPROCESSORS
from ...utils import feature_loader
from sklearn.preprocessing import normalize
from sklearn.decomposition import TruncatedSVD as SKSVD
from typing import Dict, List
@DIMPROCESSORS.register
class SVD(DimProcessorBase):
"""
Do the SVD transformation for dimension reduction.
Hyper-Params:
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done
(in SVD, we will minus origin dimension by 1).
whiten (bool): whether do whiten for each part.
train_fea_dir (str): the path of features for training SVD.
l2 (bool): whether do l2-normalization for the training features.
"""
default_hyper_params = {
"proj_dim": 0,
"whiten": True,
"train_fea_dir": "unknown",
"l2": True,
}
def __init__(self, feature_names: List[str], hps: Dict or None = None):
"""
Args:
feature_names (list): a list of features names to be loaded.
hps (dict): default hyper parameters in a dict (keys, values).
"""
super(SVD, self).__init__(feature_names, hps)
self.svd = SKSVD(n_components=self._hyper_params["proj_dim"])
self.std = 0.0
self._train(self._hyper_params["train_fea_dir"])
def _train(self, fea_dir: str) -> None:
"""
Train the SVD.
Args:
fea_dir: the path of features for training SVD.
"""
train_fea, _, _ = feature_loader.load(fea_dir, self.feature_names)
if self._hyper_params["l2"]:
train_fea = normalize(train_fea, norm="l2")
train_fea = self.svd.fit_transform(train_fea)
self.std = train_fea.std(axis=0, keepdims=True)
def __call__(self, fea: np.ndarray) -> np.ndarray:
ori_fea = fea
proj_fea = self.svd.transform(ori_fea)
if self._hyper_params["whiten"]:
proj_fea = proj_fea / (self.std + 1e-6)
return proj_fea

Binary file not shown.

Some files were not shown because too many files have changed in this diff.