mirror of https://github.com/PyRetri/PyRetri.git
upload
parent
d4145ece41
commit
0b23903d0b
|
@ -0,0 +1,73 @@
|
|||
# PyRetri
|
||||
|
||||
## Introduction
|
||||
|
||||
PyRetri is an open source deep learning-based image retrieval toolbox based on PyTorch.
|
||||
|
||||
### Major features
|
||||
|
||||
- **Modular Design**
|
||||
|
||||
We decompose deep learning-based image retrieval into several stages, and one can easily construct an image retrieval pipeline by combining different modules.
|
||||
|
||||
- **Flexible Loading**
|
||||
|
||||
The toolbox can load several types of model parameters, including parameters with the same keys and shapes, parameters with different keys, and parameters with the same keys but different shapes (see the sketch after this feature list).
|
||||
|
||||
- **Support of Multiple Methods**
|
||||
|
||||
The toolbox directly supports popular methods designed for deep learning-based image retrieval.
|
||||
|
||||
- **Combinations Search Tool**
|
||||
|
||||
We provide pipeline combinations search scripts to help users find promising combinations of approaches with various hyper-parameters.
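To illustrate the flexible-loading idea above, here is a minimal, hypothetical sketch (not PyRetri's internal loader): parameters whose keys and shapes match are copied, keys can be remapped, and everything else is skipped.

```python
import torch

def load_flexible(model, ckpt_path, key_map=None):
    """Copy checkpoint parameters whose (possibly remapped) keys and shapes match."""
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("state_dict", state)        # some checkpoints nest the weights
    key_map = key_map or {}
    model_state = model.state_dict()
    matched = {}
    for k, v in state.items():
        k = key_map.get(k, k)                     # handle parameters with different keys
        if k in model_state and model_state[k].shape == v.shape:
            matched[k] = v                        # same key and shape: keep it
    model.load_state_dict(matched, strict=False)  # silently skip everything else
    return model
```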
|
||||
|
||||
### Supported Methods
|
||||
|
||||
- **Pre-processing**
|
||||
- DirectResize, PadResize, ShorterResize
|
||||
- CenterCrop, TenCrop
|
||||
- TwoFlip
|
||||
- ToTensor, ToCaffeTensor
|
||||
- Normalize
|
||||
- **Feature Representation**
|
||||
- GAP
|
||||
- GMP
|
||||
- [R-MAC](https://arxiv.org/pdf/1511.05879.pdf)
|
||||
- [SPoC](https://arxiv.org/pdf/1510.07493.pdf)
|
||||
- [CroW](https://arxiv.org/pdf/1512.04065.pdf)
|
||||
- [SCDA](http://www.weixiushen.com/publication/tip17SCDA.pdf)
|
||||
- [GeM](https://pdfs.semanticscholar.org/a2ca/e0ed91d8a3298b3209fc7ea0a4248b914386.pdf)
|
||||
- [PWA](https://arxiv.org/abs/1705.01247)
|
||||
- [PCB](http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf)
|
||||
- **Post-processing**
|
||||
- [SVD](https://link.springer.com/chapter/10.1007%2F978-3-662-39778-7_10)
|
||||
- [PCA](http://pzs.dstu.dp.ua/DataMining/pca/bibl/Principal%20components%20analysis.pdf)
|
||||
- [DBA](https://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf)
|
||||
- [QE](https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf)
|
||||
- [K-reciprocal](https://arxiv.org/pdf/1701.08398.pdf)
|
||||
|
||||
## License
|
||||
|
||||
This project is released under the [Apache 2.0 license](https://github.com/hby96/pyretri/blob/master/LICENSE).
|
||||
|
||||
## Installation
|
||||
|
||||
Please refer to INSTALL.md for installation and dataset preparation.
|
||||
|
||||
## Get Started
|
||||
|
||||
Please see GETTING_STARTED.md for the basic usage of PyRetri.
|
||||
|
||||
## Model Zoo
|
||||
|
||||
Results and models are available in MODEL_ZOO.md.
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this toolbox in your research, please cite this project.
|
||||
|
||||
```shell
|
||||
|
||||
```
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
# Getting started
|
||||
|
||||
This page provides basic tutorials about the usage of PyRetri. For installation instructions and dataset preparation, please see [INSTALL.md](INSTALL.md).
|
||||
|
||||
## Make Data Json
|
||||
|
||||
After the gallery set and query set are separated, we package the information of each subset in pickle format for further processing. Three types cover the differently structured folders: `general`, `oxford` and `reid`.
|
||||
|
||||
The general object recognition dataset collects images with the same label in one directory and the folder structure should be like this:
|
||||
|
||||
```shell
|
||||
# type: general
|
||||
general_recognition
|
||||
├── class A
|
||||
│ ├── XXX.jpg
|
||||
│ └── ···
|
||||
├── class B
|
||||
│ ├── XXX.jpg
|
||||
│ └── ···
|
||||
└── ···
|
||||
```
|
||||
|
||||
Oxford5k is a typical dataset in the image retrieval field, and its folder structure is as follows:
|
||||
|
||||
```shell
|
||||
# type: oxford
|
||||
oxford
|
||||
├── gt
|
||||
│ ├── XXX.txt
|
||||
│ └── ···
|
||||
└── images
|
||||
├── XXX.jpg
|
||||
└── ···
|
||||
|
||||
```
|
||||
|
||||
A person re-identification dataset has the query set and gallery set already split, and its folder structure should look like this:
|
||||
|
||||
```shell
|
||||
# type: reid
|
||||
person_re_identification
|
||||
├── bounding_box_test
|
||||
│ ├── XXX.jpg
|
||||
│ └── ···
|
||||
├── query
|
||||
│ ├── XXX.jpg
|
||||
│ └── ···
|
||||
└── ···
|
||||
```
|
||||
|
||||
After choosing the appropriate type, you can generate the data jsons by:
|
||||
|
||||
```shell
|
||||
python3 main/make_data_json.py [-d ${dataset}] [-sp ${save_path}] [-t ${type}] [-gt ${ground_truth}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `dataset`: Path of the dataset for generating the data json file.
|
||||
- `save_path`: Path for saving the output file.
|
||||
- `type`: Type of the dataset. For datasets that collect images with the same label in one directory, use `general`. For the Oxford/Paris dataset, use `oxford`. For re-id datasets, use `reid`.
|
||||
- `ground_truth`: Path of the gt information, which is required when generating the data json file for the Oxford/Paris dataset.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
# for dataset collecting images with the same label in one directory
|
||||
python3 main/make_data_json.py -d /data/caltech101/gallery/ -sp data_jsons/caltech_gallery.json -t general
|
||||
|
||||
python3 main/make_data_json.py -d /data/caltech101/query/ -sp data_jsons/caltech_query.json -t general
|
||||
|
||||
# for oxford/paris dataset
|
||||
python3 main/make_data_json.py -d /data/cbir/oxford/gallery/ -sp data_jsons/oxford_gallery.json -t oxford -gt /data/cbir/oxford/gt/
|
||||
|
||||
python3 main/make_data_json.py -d /data/cbir/oxford/query/ -sp data_jsons/oxford_query.json -t oxford -gt /data/cbir/oxford/gt/
|
||||
|
||||
# for re-id dataset
|
||||
python3 main/make_data_json.py -d /data/market1501/bounding_box_test/ -sp data_jsons/market_gallery.json -t reid
|
||||
|
||||
python3 main/make_data_json.py -d /data/market1501/query/ -sp data_jsons/market_query.json -t reid
|
||||
```
|
||||
|
||||
Note: the Oxford/Paris dataset stores the ground truth of each query image in a txt file, so remember to pass the path of the gt folder when generating the data json for Oxford/Paris.
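The generated data json is a pickled dict, so you can load it directly as a quick sanity check. The field names below (`info_dicts`, `path`, `label`) follow the Folder dataset implementation in this repository; treat them as a sketch rather than a stable API.

```python
import pickle

with open("data_jsons/caltech_gallery.json", "rb") as f:
    data_info = pickle.load(f)

print(len(data_info["info_dicts"]))          # number of images packaged in this subset
print(data_info["info_dicts"][0]["path"])    # path of the first image
print(data_info["info_dicts"][0]["label"])   # its class label
```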
|
||||
|
||||
## Extract
|
||||
|
||||
All outputs (features and labels) will be saved to the save directory in pickle format.
|
||||
|
||||
Extract features for each data json file by:
|
||||
|
||||
```shell
|
||||
python3 main/extract_feature.py [-dj ${data_json}] [-sp ${save_path}] [-cfg ${config_file}] [-si ${save_interval}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `data_json`: Path of the data json file to be extracted.
|
||||
- `save_path`: Path for saving the output features in pickle format.
|
||||
|
||||
- `config_file`: Path of the configuration file in yaml format.
|
||||
- `save_interval`: Optional. The number of features saved in one part file; 5000 by default.
|
||||
|
||||
Examples:

```shell
|
||||
# extract features of the gallery set and the query set
|
||||
python3 main/extract_feature.py -dj data_jsons/caltech_gallery.json -sp /data/features/caltech/gallery/ -cfg configs/caltech.yaml
|
||||
|
||||
python3 main/extract_feature.py -dj data_jsons/caltech_query.json -sp /data/features/caltech/query/ -cfg configs/caltech.yaml
|
||||
```
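The saved features can also be loaded programmatically with the same helper used by `main/index.py`. The feature name below is only a placeholder and must match the `feature_names` entry of your config.

```python
from retrieval_tool_box.index import feature_loader

# load every part file under the directory and stack the named features
fea, info, _ = feature_loader.load("/data/features/caltech/gallery/", ["pool5_GAP"])
print(fea.shape, len(info))
```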
|
||||
|
||||
## Index
|
||||
|
||||
The paths of the query set features and gallery set features are specified in the config file.
|
||||
|
||||
Index the query set features by:
|
||||
|
||||
```shell
|
||||
python3 main/index.py [-cfg ${config_file}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `config_file`: Path of the configuration file in yaml format.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python3 main/index.py -cfg configs/caltech.yaml
|
||||
```
|
||||
|
||||
## Single Index
|
||||
|
||||
For visualizing results and analyzing failure cases, we provide a script for a single query image, so you can easily visualize or save the retrieval results.
|
||||
|
||||
Run single-image indexing by:
|
||||
|
||||
```shell
|
||||
python3 main/single_index.py [-cfg ${config_file}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `config_file`: Path of the configuration file in yaml format.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python3 main/single_index.py -cfg configs/caltech.yaml
|
||||
```
|
||||
|
||||
Please see single_index.py for more details.
|
||||
|
||||
## Add Your Own Module
|
||||
|
||||
We categorize the retrieval process into four components; a minimal end-to-end sketch follows the list below.
|
||||
|
||||
- model: the pre-trained model for feature extraction.
|
||||
- extract: select which layer's output to use, including splitter functions and aggregation methods.
|
||||
- index: index the features, including dimension processing, feature enhancement, distance metrics and re-ranking.
|
||||
- evaluate: evaluate the retrieval results, reporting recall and mAP.
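The sketch below wires the four components together; it simply condenses `main/extract_feature.py` and `main/index.py` from this repository, so the calls mirror the real entry points (paths are the same placeholder paths used in the examples above).

```python
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
from retrieval_tool_box.datasets import build_folder, build_loader
from retrieval_tool_box.models import build_model
from retrieval_tool_box.extract import build_extract_helper
from retrieval_tool_box.index import build_index_helper, feature_loader
from retrieval_tool_box.evaluate import build_evaluate_helper

cfg = setup_cfg(get_defaults_cfg(), "configs/caltech.yaml", [])

# model + extract: turn the images listed in a data json into features
dataset = build_folder("data_jsons/caltech_gallery.json", cfg.datasets)
loader = build_loader(dataset, cfg.datasets)
model = build_model(cfg.model)
build_extract_helper(model, cfg.extract).do_extract(loader, "/data/features/caltech/gallery/", 5000)

# index + evaluate: rank the gallery for every query, then score the ranking
query_fea, query_info, _ = feature_loader.load(cfg.index.query_fea_dir, cfg.index.feature_names)
gallery_fea, gallery_info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)
index_result_info, _, _ = build_index_helper(cfg.index).do_index(query_fea, query_info, gallery_fea)
mAP, recall_at_k = build_evaluate_helper(cfg.evaluate).do_eval(index_result_info, gallery_info)
print(mAP, recall_at_k)
```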
|
||||
|
||||
Here we show how to add your own model to extract features.
|
||||
|
||||
1. Create your model file `retrieval_tool_box/models/backbone/backbone_impl/reid_baseline.py`.
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
|
||||
from ..backbone_base import BackboneBase
|
||||
from ...registry import BACKBONES
|
||||
|
||||
@BACKBONES.register
|
||||
class ft_net(BackboneBase):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
|
||||
from ..backbone_base import BackboneBase
|
||||
from ...registry import BACKBONES
|
||||
|
||||
class FT_NET(BackboneBase):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
|
||||
@BACKBONES.register
|
||||
def ft_net():
|
||||
model = FT_NET()
|
||||
return model
|
||||
```
|
||||
|
||||
2. Import the module in `retrieval_tool_box/models/backbone/__init__.py`.
|
||||
|
||||
```python
|
||||
from .backbone_impl.reid_baseline import ft_net
|
||||
|
||||
__all__ = [
|
||||
'ft_net',
|
||||
]
|
||||
```
|
||||
|
||||
3. Use it in your config file.
|
||||
|
||||
```yaml
|
||||
model:
|
||||
name: "ft_net"
|
||||
ft_net:
|
||||
load_checkpoint: "/data/my_model_zoo/res50_market1501.pth"
|
||||
```
|
||||
|
||||
## Pipeline Combinations Search
|
||||
|
||||
Since the tricks used in each stage have a significant impact on retrieval performance, we provide pipeline combinations search scripts to help users find promising combinations of approaches with various hyper-parameters.
|
||||
|
||||
### Get into the combinations search scripts
|
||||
|
||||
```shell
|
||||
cd search/
|
||||
```
|
||||
|
||||
### Define Search Space
|
||||
|
||||
We decompose the search space into three sub search spaces: data_process, extract and index, each of which corresponds to a separate file. The search space is defined by adding methods with their hyper-parameters to the corresponding dict. You can add a search operator as follows:
|
||||
|
||||
```python
|
||||
data_processes.add(
|
||||
"PadResize224",
|
||||
{
|
||||
"batch_size": 32,
|
||||
"folder": {
|
||||
"name": "Folder"
|
||||
},
|
||||
"collate_fn": {
|
||||
"name": "CollateFn"
|
||||
},
|
||||
"transformers": {
|
||||
"names": ["PadResize", "ToTensor", "Normalize"],
|
||||
"PadResize": {
|
||||
"size": 224,
|
||||
"padding_v": [124, 116, 104]
|
||||
},
|
||||
"Normalize": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
By doing this, a data_process operator named "PadResize224" is added to the data_process sub search space and will be searched in the following process.
|
||||
|
||||
### Search
|
||||
|
||||
Similar to the image retrieval pipeline, combinations search includes two stages: search for feature extraction and search for indexing.
|
||||
|
||||
#### search for feature extraction
|
||||
|
||||
Search for the feature extraction combinations by:
|
||||
|
||||
```shell
|
||||
python3 search_extract.py [-sp ${save_path}] [-sm ${search_modules}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `save_path`: path for saving the output features in pickle format.
|
||||
- `search_modules`: name of the folder containing search space files.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python3 search_extract.py -sp /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules
|
||||
```
|
||||
|
||||
#### search for indexing
|
||||
|
||||
Search for the indexing combinations by:
|
||||
|
||||
```shell
|
||||
python3 search_query.py [-fd ${fea_dir}] [-sm ${search_modules}] [-sp ${save_path}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `fea_dir`: path of the output features extracted by the feature extraction combinations search.
|
||||
- `search_modules`: name of the folder containing search space files.
|
||||
- `save_path`: path for saving the retrieval results of each combination.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python3 search_query.py -fd /data/features/gap_gmp_gem_crow_spoc/ -sm search_modules -sp /data/features/gap_gmp_gem_crow_spoc_result.json
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# Installation
|
||||
|
||||
## Requirements
|
||||
|
||||
- Linux (Windows is not officially supported)
|
||||
- Python 3.5 or higher
|
||||
- PyTorch 1.2.0 or higher
|
||||
- torchvision 0.4.0 or higher
|
||||
- CUDA 9.0 or higher
|
||||
- numpy
|
||||
- sklearn
|
||||
- yacs
|
||||
- tqdm
|
||||
|
||||
Our experiments are conducted on the following environment:
|
||||
|
||||
- PyTorch 1.2.0
|
||||
- torchvision 0.4.0
|
||||
- CUDA 9.0
|
||||
- numpy 1.17.2
|
||||
- sklearn 0.21.3
|
||||
- tqdm 4.36.1
|
||||
|
||||
## Install PyRetri
|
||||
|
||||
1. Install PyTorch and torchvision following the official instructions.
|
||||
|
||||
2. Clone the PyRetri repository.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/PyRetri/PyRetri.git
|
||||
cd PyRetri
|
||||
```
|
||||
|
||||
3. Install PyRetri.
|
||||
|
||||
```shell
|
||||
python setup.py install
|
||||
```
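A quick way to check that the installation succeeded; this assumes the version string is exported at the package top level, as in this source tree.

```python
import retrieval_tool_box

print(retrieval_tool_box.__version__)  # "0.1.0" in this source tree
```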
|
||||
|
||||
## Prepare Datasets
|
||||
|
||||
### Datasets
|
||||
|
||||
In our experiments, we use four general image retrieval datasets and two person re-identification datasets.
|
||||
|
||||
- [Oxford5k](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/): collected by crawling Flickr images using the names of 11 different landmarks in Oxford, representing the landmark recognition task.
|
||||
- [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html): containing photos of 200 bird species, representing the fine-grained visual categorization task.
|
||||
- [Indoor](http://web.mit.edu/torralba/www/indoor.html): containing indoor scene images from 67 categories, representing the scene recognition task.
|
||||
- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/): consisting of pictures of objects from 101 categories, representing the general object recognition task.
|
||||
- [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html): containing images taken on the Tsinghua campus under 6 camera viewpoints, representing the person re-identification task.
|
||||
- [DukeMTMC-reID](https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view): containing images captured by 8 cameras, which is more challenging.
|
||||
|
||||
To reproduce our experimental results, first download these datasets, then follow the steps below to re-organize them.
|
||||
|
||||
### Split Dataset
|
||||
|
||||
For the image retrieval task, the dataset should be divided into two subsets: a query set and a gallery set. If your dataset has already been divided, you can skip this step.
|
||||
|
||||
To help you reproduce our results conveniently, we provide several txt files, each of which is the division protocol used in our experiments. These txt files can be found in the `split_file/` folder, and you can use the following command to split the datasets mentioned above:
|
||||
|
||||
```shell
|
||||
python3 main/split_dataset.py [-d ${dataset}] [-sf ${split_file}]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
|
||||
- `dataset`: Path of the dataset to be split.
|
||||
- `split_file`: Path of the division protocol txt file, with each line corresponding to one image in the form `<image_path> <is_gallery_image>`. `<image_path>` is the relative path of the image, and a value of 1 or 0 for `<is_gallery_image>` denotes that the image belongs to the gallery or query set, respectively.
|
||||
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python3 main/split_dataset.py -d /data/caltech101/ -sf split_file/caltech_split.txt
|
||||
```
|
||||
|
||||
A query folder and a gallery folder will then be created under the dataset folder.
|
||||
|
||||
Note:
|
||||
|
||||
1. For the Re-ID datasets, the images are already divided, so there is no need to split them.
|
||||
2. Since we symlink images instead of copying them when splitting the dataset, overwriting is prohibited. In other words, if you want to split the dataset again, remember to delete the previously generated folders first. A conceptual sketch of the split is shown below.
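For intuition, here is a minimal conceptual sketch of the split (not the toolbox's actual implementation): each line of the split file sends an image into either the gallery or the query folder via a symlink.

```python
import os

dataset_dir = "/data/caltech101"
with open("split_file/caltech_split.txt") as f:
    for line in f:
        rel_path, is_gallery = line.split()
        subset = "gallery" if is_gallery == "1" else "query"
        src = os.path.join(dataset_dir, rel_path)
        dst = os.path.join(dataset_dir, subset, rel_path)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        os.symlink(src, dst)  # fails if dst already exists, hence "no overwrite"
```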
|
|
@ -0,0 +1,58 @@
|
|||
# Model Zoo
|
||||
|
||||
## General image retrieval
|
||||
|
||||
### pre-trained models
|
||||
|
||||
| Training Set | Backbone | for Short | Download |
|
||||
| :------------------: | :-------: | :-------: | :----------------------------------------------------------: |
|
||||
| ImageNet | VGG-16 | I-VGG16 | [model](https://download.pytorch.org/models/vgg16-397923af.pth) |
|
||||
| Places365 | VGG-16 | P-VGG16 | [model](https://drive.google.com/open?id=1U_VWbn_0L9mSDCBGiAIFbXxMvBeOiTG9) |
|
||||
| ImageNet + Places365 | VGG-16 | H-VGG16 | [model](https://drive.google.com/open?id=11zE5kGNeeAXMhlHNv31Ye4kDcECrlJ1t) |
|
||||
| ImageNet | ResNet-50 | I-Res50 | [model](https://download.pytorch.org/models/resnet50-19c8e357.pth) |
|
||||
| Places365 | ResNet-50 | P-Res50 | [model](https://drive.google.com/open?id=1lp_nNw7hh1MQO_kBW86GG8y3_CyugdS2) |
|
||||
| ImageNet + Places365 | ResNet-50 | H-Res50 | [model](https://drive.google.com/open?id=1_USt_gOxgV4NJ9Zjw_U8Fq-1HEC_H_ki) |
|
||||
|
||||
### performance
|
||||
|
||||
| Dataset | Data Augmentation | Backbone | Pooling | Dimension Process | mAP |
|
||||
| :--------: | :------------------------: | :------: | :-----: | :------------------: | :--: |
|
||||
| Oxford5k   | ShorterResize + CenterCrop | H-VGG16  | GAP     | l2 + SVD(whiten) + l2 | 62.9 |
|
||||
| CUB-200 | ShorterResize + CenterCrop | I-Res50 | SCDA | l2 + PCA + l2 | 27.8 |
|
||||
| Indoor | DirectResize | P-Res50 | CroW | l2 + PCA + l2 | 51.8 |
|
||||
| Caltech101 | PadResize | I-Res50 | GeM | l2 + PCA + l2 | 77.9 |
|
||||
|
||||
Choosing the implementations mentioned above as baselines and adding some tricks, we have:
|
||||
|
||||
| Dataset | Implementations | mAP |
|
||||
| :--------: | :--------------------------------: | :--: |
|
||||
| Oxford5k | baseline + K-reciprocal | 72.9 |
|
||||
| CUB-200 | baseline + K-reciprocal | 38.9 |
|
||||
| Indoor | baseline + DBA + QE | 63.7 |
|
||||
| Caltech101 | baseline + DBA + QE + K-reciprocal | 86.1 |
|
||||
|
||||
## Person re-identification
|
||||
|
||||
For person re-identification, we use the model provided by [Person_reID_baseline](https://github.com/layumi/Person_reID_baseline_pytorch) and reproduce its results. In addition, we train a model on DukeMTMC-reID with the same open-source code for further experiments.
|
||||
|
||||
### pre-trained models
|
||||
|
||||
| Training Set | Backbone | for Short | Download |
|
||||
| :-----------: | :-------: | :-------: | :------: |
|
||||
| Market-1501 | ResNet-50 | M-Res50 | [model](https://drive.google.com/open?id=1-6LT_NCgp_0ps3EO-uqERrtlGnbynWD5) |
|
||||
| DukeMTMC-reID | ResNet-50 | D-Res50 | [model](https://drive.google.com/open?id=1X2Tiv-SQH3FxwClvBUalWkLqflgZHb9m) |
|
||||
|
||||
### performance
|
||||
|
||||
| Dataset | Data Augmentation | Backbone | Pooling | Dimension Process | mAP | Recall@1 |
|
||||
| :-----------: | :--------------------: | :------: | :-----: | :---------------: | :--: | :------: |
|
||||
| Market-1501 | DirectResize + TwoFlip | M-Res50 | GAP | l2 | 71.6 | 88.8 |
|
||||
| DukeMTMC-reID | DirectResize + TwoFlip | D-Res50 | GAP | l2 | 62.5 | 80.4 |
|
||||
|
||||
Choosing the implementations mentioned above as baselines and adding some tricks, we have:
|
||||
|
||||
| Dataset | Implementations | mAP | Recall@1 |
|
||||
| :-----------: | :-------------------------------------: | :--: | :------: |
|
||||
| Market-1501 | Baseline + l2 + PCA + l2 + K-reciprocal | 84.8 | 90.4 |
|
||||
| DukeMTMC-reID | Baseline + l2 + PCA + l2 + K-reciprocal | 78.3 | 84.2 |
|
||||
|
Binary file not shown.
|
@ -0,0 +1,52 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
|
||||
from retrieval_tool_box.datasets import build_folder, build_loader
|
||||
from retrieval_tool_box.models import build_model
|
||||
from retrieval_tool_box.extract import build_extract_helper
|
||||
|
||||
from torchvision import models
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||
parser.add_argument('--data_json', '-dj', default=None, type=str, help='json file for dataset to be extracted')
|
||||
parser.add_argument('--save_path', '-sp', default=None, type=str, help='save path for features')
|
||||
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
|
||||
parser.add_argument('--save_interval', '-si', default=5000, type=int, help='number of features saved in one part file')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# init args
|
||||
args = parse_args()
|
||||
assert args.data_json is not None, 'the dataset json must be provided!'
|
||||
assert args.save_path is not None, 'the save path must be provided!'
|
||||
assert args.config_file is not None, 'a config file must be provided!'
|
||||
|
||||
# init and load retrieval pipeline settings
|
||||
cfg = get_defaults_cfg()
|
||||
cfg = setup_cfg(cfg, args.config_file, args.opts)
|
||||
|
||||
# build dataset and dataloader
|
||||
dataset = build_folder(args.data_json, cfg.datasets)
|
||||
dataloader = build_loader(dataset, cfg.datasets)
|
||||
|
||||
# build model
|
||||
model = build_model(cfg.model)
|
||||
|
||||
# build helper and extract features
|
||||
extract_helper = build_extract_helper(model, cfg.extract)
|
||||
extract_helper.do_extract(dataloader, args.save_path, args.save_interval)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,48 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
|
||||
from retrieval_tool_box.index import build_index_helper, feature_loader
|
||||
from retrieval_tool_box.evaluate import build_evaluate_helper
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# init args
|
||||
args = parse_args()
|
||||
assert args.config_file is not None, 'a config file must be provided!'
|
||||
|
||||
# init and load retrieval pipeline settings
|
||||
cfg = get_defaults_cfg()
|
||||
cfg = setup_cfg(cfg, args.config_file, args.opts)
|
||||
|
||||
# load features
|
||||
query_fea, query_info, _ = feature_loader.load(cfg.index.query_fea_dir, cfg.index.feature_names)
|
||||
gallery_fea, gallery_info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)
|
||||
|
||||
# build helper and index features
|
||||
index_helper = build_index_helper(cfg.index)
|
||||
index_result_info, query_fea, gallery_fea = index_helper.do_index(query_fea, query_info, gallery_fea)
|
||||
|
||||
# build helper and evaluate results
|
||||
evaluate_helper = build_evaluate_helper(cfg.evaluate)
|
||||
mAP, recall_at_k = evaluate_helper.do_eval(index_result_info, gallery_info)
|
||||
|
||||
# show results
|
||||
evaluate_helper.show_results(mAP, recall_at_k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,36 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
|
||||
from retrieval_tool_box.extract import make_data_json
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||
parser.add_argument('--dataset', '-d', default=None, type=str, help="path for the dataset that make the json file")
|
||||
parser.add_argument('--save_path', '-sp', default=None, type=str, help="save path for the json file")
|
||||
parser.add_argument('--type', '-t', default=None, type=str, help="mode of the dataset")
|
||||
parser.add_argument('--ground_truth', '-gt', default=None, type=str, help="ground truth of the dataset")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# init args
|
||||
args = parse_args()
|
||||
assert args.dataset is not None, 'the data must be provided!'
|
||||
assert args.save_path is not None, 'the save path must be provided!'
|
||||
assert args.type is not None, 'the type must be provided!'
|
||||
|
||||
# make data json
|
||||
make_data_json(args.dataset, args.save_path, args.type, args.ground_truth)
|
||||
|
||||
print('make data json done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,68 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from retrieval_tool_box.config import get_defaults_cfg, setup_cfg
|
||||
from retrieval_tool_box.datasets import build_transformers
|
||||
from retrieval_tool_box.models import build_model
|
||||
from retrieval_tool_box.extract import build_extract_helper
|
||||
from retrieval_tool_box.index import build_index_helper, feature_loader
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||
parser.add_argument('--config_file', '-cfg', default=None, metavar='FILE', type=str, help='path to config file')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# init args
|
||||
args = parse_args()
|
||||
assert args.config_file is not None, 'a config file must be provided!'
|
||||
assert os.path.exists(args.config_file), 'the config file must exist!'
|
||||
|
||||
# init and load retrieval pipeline settings
|
||||
cfg = get_defaults_cfg()
|
||||
cfg = setup_cfg(cfg, args.config_file, args.opts)
|
||||
|
||||
# set path for single image
|
||||
path = '/data/caltech101/query/airplanes/image_0004.jpg'
|
||||
|
||||
# build transformers
|
||||
transformers = build_transformers(cfg.datasets.transformers)
|
||||
|
||||
# build model
|
||||
model = build_model(cfg.model)
|
||||
|
||||
# read image and convert it to tensor
|
||||
img = Image.open(path).convert("RGB")
|
||||
img_tensor = transformers(img)
|
||||
|
||||
# build helper and extract feature for single image
|
||||
extract_helper = build_extract_helper(model, cfg.extract)
|
||||
img_fea_info = extract_helper.do_single_extract(img_tensor)
|
||||
stacked_feature = list()
|
||||
for name in cfg.index.feature_names:
|
||||
assert name in img_fea_info[0], "invalid feature name: {} not in {}!".format(name, img_fea_info[0].keys())
|
||||
stacked_feature.append(img_fea_info[0][name])
|
||||
img_fea = np.concatenate(stacked_feature, axis=1)
|
||||
|
||||
# load gallery features
|
||||
gallery_fea, gallery_info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)
|
||||
|
||||
# build helper and single index feature
|
||||
index_helper = build_index_helper(cfg.index)
|
||||
index_result_info, query_fea, gallery_fea = index_helper.do_index(img_fea, img_fea_info, gallery_fea)
|
||||
|
||||
# index_helper.show_topk_retrieved_images(index_result_info[0], 4, gallery_info)
|
||||
index_helper.save_topk_retrieved_images('../retrieved_images', index_result_info[0], 5, gallery_info)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from retrieval_tool_box.extract.utils import split_dataset
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='A tool box for deep learning-based image retrieval')
|
||||
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
|
||||
parser.add_argument('--dataset', '-d', default=None, type=str, help="path for the dataset.")
|
||||
parser.add_argument('--split_file', '-sf', default=None, type=str, help="name for the dataset.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# init args
|
||||
args = parse_args()
|
||||
assert args.dataset is not None, 'the dataset must be provided!'
|
||||
assert args.split_file is not None, 'the split file must be provided!'
|
||||
|
||||
# split dataset
|
||||
split_dataset(args.dataset, args.split_file)
|
||||
|
||||
print('split dataset done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
numpy
|
||||
torch>=1.1
|
||||
torchvision>=0.4
|
||||
sklearn
|
||||
yacs
|
||||
tqdm
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
__version__ = "0.1.0"
|
|
@ -0,0 +1,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .config import setup_cfg, get_defaults_cfg
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_defaults_cfg',
|
||||
'setup_cfg',
|
||||
]
|
|
@ -0,0 +1,46 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from ..datasets import get_datasets_cfg
|
||||
from ..models import get_model_cfg
|
||||
from ..extract import get_extract_cfg
|
||||
from ..index import get_index_cfg
|
||||
from ..evaluate import get_evaluate_cfg
|
||||
|
||||
|
||||
def get_defaults_cfg() -> CfgNode:
|
||||
"""
|
||||
Construct the default configuration tree.
|
||||
|
||||
Returns:
|
||||
cfg (CfgNode): the default configuration tree.
|
||||
"""
|
||||
cfg = CfgNode()
|
||||
|
||||
cfg["datasets"] = get_datasets_cfg()
|
||||
cfg["model"] = get_model_cfg()
|
||||
cfg["extract"] = get_extract_cfg()
|
||||
cfg["index"] = get_index_cfg()
|
||||
cfg["evaluate"] = get_evaluate_cfg()
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def setup_cfg(cfg: CfgNode, cfg_file: str, cfg_opts: list or None = None) -> CfgNode:
|
||||
"""
|
||||
Load a yaml config file and merge it into this CfgNode.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree with default structure.
|
||||
cfg_file (str): the path for yaml config file which is matched with the CfgNode.
|
||||
cfg_opts (list, optional): config (keys, values) in a list (e.g., from command line) into this CfgNode.
|
||||
|
||||
Returns:
|
||||
cfg (CfgNode): the configuration tree with settings in the config file.
|
||||
"""
|
||||
cfg.merge_from_file(cfg_file)
|
||||
cfg.merge_from_list(cfg_opts)
|
||||
cfg.freeze()
|
||||
|
||||
return cfg
|
Binary file not shown.
|
@ -0,0 +1,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .builder import build_collate, build_folder, build_transformers, build_loader
|
||||
|
||||
from .config import get_datasets_cfg
|
||||
|
||||
__all__ = [
|
||||
'get_datasets_cfg',
|
||||
'build_collate', 'build_folder', 'build_transformers', 'build_loader',
|
||||
]
|
|
@ -0,0 +1,81 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
|
||||
from .collate_fn import CollateFnBase
|
||||
from .folder import FolderBase
|
||||
from .transformer import TransformerBase
|
||||
|
||||
from ..utils import simple_build
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
|
||||
def build_collate(cfg: CfgNode) -> CollateFnBase:
|
||||
"""
|
||||
Instantiate a collate class with the given configuration tree.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
collate (CollateFnBase): a collate class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
collate = simple_build(name, cfg, COLLATEFNS)
|
||||
return collate
|
||||
|
||||
|
||||
def build_transformers(cfg: CfgNode) -> Compose:
|
||||
"""
|
||||
Instantiate a compose class containing several transforms with the given configuration tree.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
transformers (Compose): a compose class.
|
||||
"""
|
||||
names = cfg["names"]
|
||||
transformers = list()
|
||||
for name in names:
|
||||
transformers.append(simple_build(name, cfg, TRANSFORMERS))
|
||||
transformers = Compose(transformers)
|
||||
return transformers
|
||||
|
||||
|
||||
def build_folder(data_json_path: str, cfg: CfgNode) -> FolderBase:
|
||||
"""
|
||||
Instantiate a folder class with the given configuration tree.
|
||||
|
||||
Args:
|
||||
data_json_path (str): the path of the data json file.
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
folder (FolderBase): a folder class.
|
||||
"""
|
||||
trans = build_transformers(cfg.transformers)
|
||||
folder = simple_build(cfg.folder["name"], cfg.folder, FOLDERS, data_json_path=data_json_path, transformer=trans)
|
||||
return folder
|
||||
|
||||
|
||||
def build_loader(folder: FolderBase, cfg: CfgNode) -> DataLoader:
|
||||
"""
|
||||
Instantiate a data loader class with the given configuration tree.
|
||||
|
||||
Args:
|
||||
folder (FolderBase): the folder function.
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
data_loader (DataLoader): a data loader class.
|
||||
"""
|
||||
co_fn = build_collate(cfg.collate_fn)
|
||||
|
||||
data_loader = DataLoader(folder, cfg["batch_size"], collate_fn=co_fn, num_workers=8, pin_memory=True)
|
||||
|
||||
return data_loader
|
|
@ -0,0 +1,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .collate_fn_impl.collate_fn import CollateFn
|
||||
from .collate_fn_base import CollateFnBase
|
||||
|
||||
__all__ = [
|
||||
'CollateFnBase',
|
||||
'CollateFn',
|
||||
]
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class CollateFnBase(ModuleBase):
|
||||
"""
|
||||
The base class of collate function.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps: default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(CollateFnBase, self).__init__(hps)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, batch: List[Dict]) -> Dict[str, torch.tensor]:
|
||||
pass
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..collate_fn_base import CollateFnBase
|
||||
from ...registry import COLLATEFNS
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@COLLATEFNS.register
|
||||
class CollateFn(CollateFnBase):
|
||||
"""
|
||||
A wrapper for torch.utils.data.dataloader.default_collate.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(CollateFn, self).__init__(hps)
|
||||
|
||||
def __call__(self, batch: List[Dict]) -> Dict[str, torch.tensor]:
|
||||
assert isinstance(batch, list)
|
||||
assert isinstance(batch[0], dict)
|
||||
return default_collate(batch)
|
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
|
||||
|
||||
from ..utils import get_config_from_registry
|
||||
|
||||
def get_collate_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(COLLATEFNS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_folder_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(FOLDERS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_tranformers_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(TRANSFORMERS)
|
||||
cfg["names"] = ["unknown"]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_datasets_cfg() -> CfgNode:
|
||||
cfg = CfgNode()
|
||||
cfg["collate_fn"] = get_collate_cfg()
|
||||
cfg["folder"] = get_folder_cfg()
|
||||
cfg["transformers"] = get_tranformers_cfg()
|
||||
cfg["batch_size"] = 1
|
||||
return cfg
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .folder_impl.folder import Folder
|
||||
from .folder_base import FolderBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'FolderBase',
|
||||
'Folder',
|
||||
]
|
|
@ -0,0 +1,59 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pickle
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class FolderBase(ModuleBase):
|
||||
"""
|
||||
The base class of folder function.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, data_json_path: str, transformer: callable or None = None, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
data_json_path (str): the path for data json file.
|
||||
transformer (callable): a list of data augmentation operations.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(FolderBase, self).__init__(hps)
|
||||
with open(data_json_path, "rb") as f:
|
||||
self.data_info = pickle.load(f)
|
||||
self.data_json_path = data_json_path
|
||||
self.transformer = transformer
|
||||
|
||||
def __len__(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
pass
|
||||
|
||||
def find_classes(self, info_dicts: Dict) -> (List, Dict):
|
||||
pass
|
||||
|
||||
def read_img(self, path: str) -> Image:
|
||||
"""
|
||||
Load image.
|
||||
|
||||
Args:
|
||||
path (str): the path of the image.
|
||||
|
||||
Returns:
|
||||
image (Image): shape (H, W, C).
|
||||
"""
|
||||
try:
|
||||
img = Image.open(path)
|
||||
img = img.convert("RGB")
|
||||
return img
|
||||
except Exception as e:
|
||||
print('[DataSet]: WARNING image can not be loaded: {}'.format(str(e)))
|
||||
return None
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pickle
|
||||
|
||||
from ..folder_base import FolderBase
|
||||
from ...registry import FOLDERS
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@FOLDERS.register
|
||||
class Folder(FolderBase):
|
||||
"""
|
||||
A folder function for loading images.
|
||||
|
||||
Hyper-Params:
|
||||
use_bbox: bool, whether use bbox to crop image. When set to true,
|
||||
make sure that bbox attribute is provided in your data json and bbox format is [x1, y1, x2, y2].
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"use_bbox": False,
|
||||
}
|
||||
|
||||
def __init__(self, data_json_path: str, transformer: callable or None = None, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
data_json_path (str): the path for data json file.
|
||||
transformer (callable): a list of data augmentation operations.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(Folder, self).__init__(data_json_path, transformer, hps)
|
||||
self.classes, self.class_to_idx = self.find_classes(self.data_info["info_dicts"])
|
||||
|
||||
def find_classes(self, info_dicts: Dict) -> (List, Dict):
|
||||
"""
|
||||
Get the class names and the mapping relations.
|
||||
|
||||
Args:
|
||||
info_dicts (dict): the dataset information contained in the data json file.
|
||||
|
||||
Returns:
|
||||
tuple (list, dict): a list of class names and a dict for projecting class name into int label.
|
||||
"""
|
||||
classes = list()
|
||||
for i in range(len(info_dicts)):
|
||||
if info_dicts[i]["label"] not in classes:
|
||||
classes.append(info_dicts[i]["label"])
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
return classes, class_to_idx
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get the number of total training samples.
|
||||
|
||||
Returns:
|
||||
length (int): the number of total training samples.
|
||||
"""
|
||||
return len(self.data_info["info_dicts"])
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""
|
||||
Load the image and convert it to tensor for training.
|
||||
|
||||
Args:
|
||||
idx (int): the serial number of the image.
|
||||
|
||||
Returns:
|
||||
item (dict): the dict containing the image after augmentations, serial number and label.
|
||||
"""
|
||||
info = self.data_info["info_dicts"][idx]
|
||||
img = self.read_img(info["path"])
|
||||
if self._hyper_params["use_bbox"]:
|
||||
assert info["bbox"] is not None, 'image {} does not have a bbox'.format(info["path"])
|
||||
x1, y1, x2, y2 = info["bbox"]
|
||||
box = map(int, (x1, y1, x2, y2))
|
||||
img = img.crop(box)
|
||||
img = self.transformer(img)
|
||||
return {"img": img, "idx": idx, "label": self.class_to_idx[info["label"]]}
|
|
@ -0,0 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..utils.registry import Registry
|
||||
|
||||
COLLATEFNS = Registry()
|
||||
FOLDERS = Registry()
|
||||
TRANSFORMERS = Registry()
|
Binary file not shown.
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .transformers_impl.transformers import (DirectResize, PadResize, ShorterResize, CenterCrop,
|
||||
ToTensor, ToCaffeTensor, Normalize, TenCrop, TwoFlip)
|
||||
from .transformers_base import TransformerBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'TransformerBase',
|
||||
'DirectResize', 'PadResize', 'ShorterResize', 'CenterCrop', 'ToTensor', 'ToCaffeTensor',
|
||||
'Normalize', 'TenCrop', 'TwoFlip',
|
||||
]
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import ModuleBase
|
||||
from ...utils import Registry
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class TransformerBase(ModuleBase):
|
||||
"""
|
||||
The base class of data augmentation operations.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(TransformerBase, self).__init__(hps)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, img: Image) -> Image or torch.tensor:
|
||||
pass
|
Binary file not shown.
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,259 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ..transformers_base import TransformerBase
|
||||
from ...registry import TRANSFORMERS
|
||||
from torchvision.transforms import Resize as TResize
|
||||
from torchvision.transforms import TenCrop as TTenCrop
|
||||
from torchvision.transforms import CenterCrop as TCenterCrop
|
||||
from torchvision.transforms import ToTensor as TToTensor
|
||||
from torchvision.transforms.functional import hflip
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class DirectResize(TransformerBase):
|
||||
"""
|
||||
Directly resize the image to the target size, regardless of the h:w ratio.
|
||||
|
||||
Hyper-Params
|
||||
size (sequence): desired output size.
|
||||
interpolation (int): desired interpolation.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"size": (224, 224),
|
||||
"interpolation": Image.BILINEAR,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(DirectResize, self).__init__(hps)
|
||||
self.t_transformer = TResize(self._hyper_params["size"], self._hyper_params["interpolation"])
|
||||
|
||||
def __call__(self, img: Image) -> Image:
|
||||
return self.t_transformer(img)
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class PadResize(TransformerBase):
|
||||
"""
|
||||
Resize image's longer edge to target size, and then pad the shorter edge to target size.
|
||||
|
||||
Hyper-Params
|
||||
size (int): desired output size of the longer edge.
|
||||
padding_v (sequence): padding pixel value.
|
||||
interpolation (int): desired interpolation.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"size": 224,
|
||||
"padding_v": [124, 116, 104],
|
||||
"interpolation": Image.BILINEAR,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps: default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PadResize, self).__init__(hps)
|
||||
|
||||
def __call__(self, img: Image) -> Image:
|
||||
target_size = self._hyper_params["size"]
|
||||
padding_v = tuple(self._hyper_params["padding_v"])
|
||||
interpolation = self._hyper_params["interpolation"]
|
||||
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
img = img.resize((int(target_size), int(h * target_size * 1.0 / w)), interpolation)
|
||||
else:
|
||||
img = img.resize((int(w * target_size * 1.0 / h), int(target_size)), interpolation)
|
||||
|
||||
ret_img = Image.new("RGB", (target_size, target_size), padding_v)
|
||||
w, h = img.size
|
||||
st_w = int((ret_img.size[0] - w) / 2.0)
|
||||
st_h = int((ret_img.size[1] - h) / 2.0)
|
||||
ret_img.paste(img, (st_w, st_h))
|
||||
return ret_img
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class ShorterResize(TransformerBase):
|
||||
"""
|
||||
Resize the image's shorter edge to the target size, while keeping the h:w ratio.
|
||||
|
||||
Hyper-Params
|
||||
size (int): desired output size.
|
||||
interpolation (int): desired interpolation.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"size": 224,
|
||||
"interpolation": Image.BILINEAR,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(ShorterResize, self).__init__(hps)
|
||||
self.t_transformer = TResize(self._hyper_params["size"], self._hyper_params["interpolation"])
|
||||
|
||||
def __call__(self, img: Image) -> Image:
|
||||
return self.t_transformer(img)
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class CenterCrop(TransformerBase):
|
||||
"""
|
||||
A wrapper of CenterCrop in PyTorch; see torchvision.transforms.CenterCrop for explanation.
|
||||
|
||||
Hyper-Params
|
||||
size(sequence or int): desired output size.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"size": 224,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(CenterCrop, self).__init__(hps)
|
||||
self.t_transformer = TCenterCrop(self._hyper_params["size"])
|
||||
|
||||
def __call__(self, img: Image) -> Image:
|
||||
return self.t_transformer(img)
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class ToTensor(TransformerBase):
|
||||
"""
|
||||
A wrapper of ToTensor in PyTorch; see torchvision.transforms.ToTensor for explanation.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(ToTensor, self).__init__(hps)
|
||||
self.t_transformer = TToTensor()
|
||||
|
||||
def __call__(self, imgs: Image or tuple) -> torch.Tensor:
|
||||
if not isinstance(imgs, tuple):
|
||||
imgs = [imgs]
|
||||
ret_tensor = list()
|
||||
for img in imgs:
|
||||
ret_tensor.append(self.t_transformer(img))
|
||||
ret_tensor = torch.stack(ret_tensor, dim=0)
|
||||
return ret_tensor
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class ToCaffeTensor(TransformerBase):
|
||||
"""
|
||||
Create tensors for models trained in caffe.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(ToCaffeTensor, self).__init__(hps)
|
||||
|
||||
def __call__(self, imgs: Image or tuple) -> torch.tensor:
|
||||
if not isinstance(imgs, tuple):
|
||||
imgs = [imgs]
|
||||
|
||||
ret_tensor = list()
|
||||
for img in imgs:
|
||||
img = np.array(img, np.int32, copy=False)
|
||||
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
|
||||
img = np.stack([b, g, r], axis=2)
|
||||
img = torch.from_numpy(img)
|
||||
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
||||
img = img.float()
|
||||
ret_tensor.append(img)
|
||||
ret_tensor = torch.stack(ret_tensor, dim=0)
|
||||
return ret_tensor
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class Normalize(TransformerBase):
|
||||
"""
|
||||
Normalize a tensor image with mean and standard deviation.
|
||||
|
||||
Hyper-Params
|
||||
mean (sequence): sequence of means for each channel.
|
||||
std (sequence): sequence of standard deviations for each channel.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(Normalize, self).__init__(hps)
|
||||
for v in ["mean", "std"]:
|
||||
self.__dict__[v] = np.array(self._hyper_params[v])[None, :, None, None]
|
||||
self.__dict__[v] = torch.from_numpy(self.__dict__[v]).float()
|
||||
|
||||
def __call__(self, tensor: torch.tensor) -> torch.tensor:
|
||||
assert tensor.ndimension() == 4
|
||||
tensor.sub_(self.mean).div_(self.std)
|
||||
return tensor
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class TenCrop(TransformerBase):
|
||||
"""
|
||||
A wrapper of TenCrop in PyTorch; see torchvision.transforms.TenCrop for explanation.
|
||||
|
||||
Hyper-Params
|
||||
size (sequence or int): desired output size.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"size": 224,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(TenCrop, self).__init__(hps)
|
||||
self.t_transformer = TTenCrop(self._hyper_params["size"])
|
||||
|
||||
def __call__(self, img: Image) -> Image:
|
||||
return self.t_transformer(img)
|
||||
|
||||
|
||||
@TRANSFORMERS.register
|
||||
class TwoFlip(TransformerBase):
|
||||
"""
|
||||
Return the image itself and its horizontally flipped copy.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(TwoFlip, self).__init__(hps)
|
||||
|
||||
def __call__(self, img: Image) -> (Image, Image):
|
||||
return img, hflip(img)
|
|
@ -0,0 +1,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .config import get_evaluate_cfg
|
||||
from .builder import build_evaluate_helper
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_evaluate_cfg',
|
||||
'build_evaluate_helper',
|
||||
]
|
|
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import EVALUATORS
|
||||
from .evaluator import EvaluatorBase
|
||||
from .helper import EvaluateHelper
|
||||
|
||||
from ..utils import simple_build
|
||||
|
||||
|
||||
def build_evaluator(cfg: CfgNode) -> EvaluatorBase:
|
||||
"""
|
||||
Instantiate an evaluator class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
evaluator (EvaluatorBase): an evaluator class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
evaluator = simple_build(name, cfg, EVALUATORS)
|
||||
return evaluator
|
||||
|
||||
|
||||
def build_evaluate_helper(cfg: CfgNode) -> EvaluateHelper:
|
||||
"""
|
||||
Instantiate an evaluate helper class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
helper (EvaluateHelper): a evaluate helper class.
|
||||
"""
|
||||
evaluator = build_evaluator(cfg.evaluator)
|
||||
helper = EvaluateHelper(evaluator)
|
||||
return helper
|
|
@ -0,0 +1,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import EVALUATORS
|
||||
|
||||
from ..utils import get_config_from_registry
|
||||
|
||||
|
||||
def get_evaluator_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(EVALUATORS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
def get_evaluate_cfg() -> CfgNode:
|
||||
cfg = CfgNode()
|
||||
cfg["evaluator"] = get_evaluator_cfg()
|
||||
return cfg
|
|
@ -0,0 +1,13 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .evaluators_impl.overall import OverAll
|
||||
from .evaluators_impl.oxford_overall import OxfordOverAll
|
||||
from .evaluators_impl.reid_overall import ReIDOverAll
|
||||
from .evaluators_base import EvaluatorBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'EvaluatorBase',
|
||||
'OverAll', 'OxfordOverAll',
|
||||
'ReIDOverAll',
|
||||
]
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class EvaluatorBase(ModuleBase):
|
||||
"""
|
||||
The base class of evaluators which compute mAP and recall.
|
||||
"""
|
||||
default_hyper_params = {}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
super(EvaluatorBase, self).__init__(hps)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, query_result: Dict, gallery_info: Dict) -> (float, Dict):
|
||||
pass
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,86 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..evaluators_base import EvaluatorBase
|
||||
from ...registry import EVALUATORS
|
||||
|
||||
from sklearn.metrics import average_precision_score
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@EVALUATORS.register
|
||||
class OverAll(EvaluatorBase):
|
||||
"""
|
||||
An evaluator for mAP and recall computation.
|
||||
|
||||
Hyper-Params
|
||||
recall_k (sequence): positions of recalls to be calculated.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"recall_k": [1, 2, 4, 8],
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(OverAll, self).__init__(hps)
|
||||
self._hyper_params["recall_k"] = np.sort(self._hyper_params["recall_k"])
|
||||
|
||||
def compute_recall_at_k(self, gt: List[bool], result_dict: Dict) -> None:
|
||||
"""
|
||||
Calculate the recall at each position.
|
||||
|
||||
Args:
|
||||
gt (sequence): a list of bools indicating whether each result matches the query label.
|
||||
result_dict (dict): a dict of per-position recall counters, updated in place.
|
||||
"""
|
||||
ks = self._hyper_params["recall_k"]
|
||||
gt = gt[:ks[-1]]
|
||||
first_tp = np.where(gt)[0]
|
||||
if len(first_tp) == 0:
|
||||
return
|
||||
for k in ks:
|
||||
if k >= first_tp[0] + 1:
|
||||
result_dict[k] = result_dict[k] + 1
|
||||
|
||||
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
|
||||
"""
|
||||
Calculate the mAP and recall for the indexing results.
|
||||
|
||||
Args:
|
||||
query_result (list): a list of indexing results.
|
||||
gallery_info (list): a list of gallery set information.
|
||||
|
||||
Returns:
|
||||
tuple (float, dict): mean average precision and recall for each position.
|
||||
"""
|
||||
aps = list()
|
||||
|
||||
# For mAP calculation
|
||||
pseudo_score = np.arange(0, len(gallery_info))[::-1]
|
||||
|
||||
recall_at_k = dict()
|
||||
for k in self._hyper_params["recall_k"]:
|
||||
recall_at_k[k] = 0
|
||||
|
||||
gallery_label = np.array([gallery_info[idx]["label_idx"] for idx in range(len(gallery_info))])
|
||||
for i in range(len(query_result)):
|
||||
ranked_idx = query_result[i]["ranked_neighbors_idx"]
|
||||
gt = (gallery_label[query_result[i]["ranked_neighbors_idx"]] == query_result[i]["label_idx"])
|
||||
aps.append(average_precision_score(gt, pseudo_score[:len(gt)]))
|
||||
|
||||
# deal with 'gallery as query' test
|
||||
if gallery_info[ranked_idx[0]]["path"] == query_result[i]["path"]:
|
||||
gt = gt[1:]  # gt is a numpy array; drop the first entry (the query itself)
|
||||
self.compute_recall_at_k(gt, recall_at_k)
|
||||
|
||||
mAP = np.mean(aps) * 100
|
||||
|
||||
for k in recall_at_k:
|
||||
recall_at_k[k] = recall_at_k[k] * 100 / len(query_result)
|
||||
|
||||
return mAP, recall_at_k
|
|
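The mAP computation above feeds `average_precision_score` a descending pseudo score so that the ranked order itself acts as the confidence. A tiny worked example (illustration only, not part of the toolbox):

```python
import numpy as np
from sklearn.metrics import average_precision_score

gt = np.array([True, False, True, False])          # relevance of the top-4 ranked results
pseudo_score = np.arange(len(gt))[::-1]            # [3, 2, 1, 0]: rank order used as score
print(average_precision_score(gt, pseudo_score))   # 0.833... = (1/1 + 2/3) / 2
```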
@ -0,0 +1,132 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..evaluators_base import EvaluatorBase
|
||||
from ...registry import EVALUATORS
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@EVALUATORS.register
|
||||
class OxfordOverAll(EvaluatorBase):
|
||||
"""
|
||||
An evaluator for Oxford mAP and recall computation.
|
||||
|
||||
Hyper-Params
|
||||
gt_dir (str): the path of the oxford ground truth.
|
||||
recall_k (sequence): positions of recalls to be calculated.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"gt_dir": "/data/cbir/oxford/gt",
|
||||
"recall_k": [1, 2, 4, 8],
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(OxfordOverAll, self).__init__(hps)
|
||||
assert os.path.exists(self._hyper_params["gt_dir"]), 'the ground truth files must be existed!'
|
||||
|
||||
@staticmethod
|
||||
def _load_tag_set(file: str) -> set:
|
||||
"""
|
||||
Read information from the txt file.
|
||||
|
||||
Args:
|
||||
file (str): the path of the txt file.
|
||||
|
||||
Returns:
|
||||
ret (set): the information.
|
||||
"""
|
||||
ret = set()
|
||||
with open(file, "r") as f:
|
||||
for line in f.readlines():
|
||||
ret.add(line.strip())
|
||||
return ret
|
||||
|
||||
def compute_ap(self, query_tag: str, ranked_tags: List[str], recall_at_k: Dict) -> float:
|
||||
"""
|
||||
Calculate the ap for one query.
|
||||
|
||||
Args:
|
||||
query_tag (str): name of the query image.
|
||||
ranked_tags (list): a list of label of the indexing results.
|
||||
recall_at_k (dict): a dict of per-position recall counters, updated in place.
|
||||
|
||||
Returns:
|
||||
ap (float): ap for one query.
|
||||
"""
|
||||
gt_prefix = os.path.join(self._hyper_params["gt_dir"], query_tag)
|
||||
good_set = self._load_tag_set(gt_prefix + "_good.txt")
|
||||
ok_set = self._load_tag_set(gt_prefix + "_ok.txt")
|
||||
junk_set = self._load_tag_set(gt_prefix + "_junk.txt")
|
||||
pos_set = set.union(good_set, ok_set)
|
||||
|
||||
old_recall = 0.0
|
||||
old_precision = 1.0
|
||||
ap = 0.0
|
||||
intersect_size = 0.0
|
||||
i = 0
|
||||
first_tp = -1
|
||||
|
||||
for tag in ranked_tags:
|
||||
if tag in junk_set:
|
||||
continue
|
||||
if tag in pos_set:
|
||||
intersect_size += 1
|
||||
|
||||
# Remember that in oxford query mode, the first element in rank_list is the query itself.
|
||||
if first_tp == -1:
|
||||
first_tp = i
|
||||
|
||||
recall = intersect_size * 1.0 / len(pos_set)
|
||||
precision = intersect_size / (i + 1.0)
|
||||
|
||||
ap += (recall - old_recall) * ((old_precision + precision) / 2.0)
|
||||
old_recall = recall
|
||||
old_precision = precision
|
||||
i += 1
|
||||
|
||||
if first_tp != -1:
|
||||
ks = self._hyper_params["recall_k"]
|
||||
for k in ks:
|
||||
if k >= first_tp + 1:
|
||||
recall_at_k[k] = recall_at_k[k] + 1
|
||||
|
||||
return ap
|
||||
|
||||
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
|
||||
"""
|
||||
Calculate the mAP and recall for the indexing results.
|
||||
|
||||
Args:
|
||||
query_result (list): a list of indexing results.
|
||||
gallery_info (list): a list of gallery set information.
|
||||
|
||||
Returns:
|
||||
tuple (float, dict): mean average precision and recall for each position.
|
||||
"""
|
||||
aps = list()
|
||||
|
||||
recall_at_k = dict()
|
||||
for k in self._hyper_params["recall_k"]:
|
||||
recall_at_k[k] = 0
|
||||
|
||||
for i in range(len(query_result)):
|
||||
ranked_idx = query_result[i]["ranked_neighbors_idx"]
|
||||
ranked_tags = list()
|
||||
for idx in ranked_idx:
|
||||
ranked_tags.append(gallery_info[idx]["label"])
|
||||
aps.append(self.compute_ap(query_result[i]["query_name"], ranked_tags, recall_at_k))
|
||||
|
||||
mAP = np.mean(aps) * 100
|
||||
|
||||
for k in recall_at_k:
|
||||
recall_at_k[k] = recall_at_k[k] * 100 / len(query_result)
|
||||
|
||||
return mAP, recall_at_k
|
|
@ -0,0 +1,144 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from ..evaluators_base import EvaluatorBase
|
||||
from ...registry import EVALUATORS
|
||||
from sklearn.metrics import average_precision_score
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@EVALUATORS.register
|
||||
class ReIDOverAll(EvaluatorBase):
|
||||
"""
|
||||
An evaluator for Re-ID mAP and recall computation.
|
||||
|
||||
Hyper-Params
|
||||
recall_k (sequence): positions of recalls to be calculated.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"recall_k": [1, 2, 4, 8],
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(ReIDOverAll, self).__init__(hps)
|
||||
self._hyper_params["recall_k"] = np.sort(self._hyper_params["recall_k"])
|
||||
|
||||
def compute_ap_cmc(self, index: np.ndarray, good_index: np.ndarray, junk_index: np.ndarray) -> (float, torch.tensor):
|
||||
"""
|
||||
Calculate the ap and cmc for one query.
|
||||
|
||||
Args:
|
||||
index (np.ndarray): the sorted retrieval index for one query.
|
||||
good_index (np.ndarray): the index for good matching.
|
||||
junk_index (np.ndarray): the index for junk matching.
|
||||
|
||||
Returns:
|
||||
tuple (float, torch.tensor): (ap, cmc), ap and cmc for one query.
|
||||
"""
|
||||
ap = 0
|
||||
cmc = torch.IntTensor(len(index)).zero_()
|
||||
if good_index.size == 0:
|
||||
cmc[0] = -1
|
||||
return ap, cmc
|
||||
|
||||
# remove junk_index
|
||||
mask = np.in1d(index, junk_index, invert=True)
|
||||
index = index[mask]
|
||||
|
||||
# find good_index index
|
||||
ngood = len(good_index)
|
||||
mask = np.in1d(index, good_index)
|
||||
rows_good = np.argwhere(mask == True)
|
||||
rows_good = rows_good.flatten()
|
||||
|
||||
cmc[rows_good[0]:] = 1
|
||||
for i in range(ngood):
|
||||
d_recall = 1.0 / ngood
|
||||
precision = (i + 1) * 1.0 / (rows_good[i] + 1)
|
||||
if rows_good[i] != 0:
|
||||
old_precision = i * 1.0 / rows_good[i]
|
||||
else:
|
||||
old_precision = 1.0
|
||||
ap = ap + d_recall * (old_precision + precision) / 2
|
||||
|
||||
return ap, cmc
|
||||
|
||||
def evaluate_once(self, index: np.ndarray, ql: int, qc: int, gl: np.ndarray, gc: np.ndarray) -> (float, torch.tensor):
|
||||
"""
|
||||
Generate the indexes and calculate the ap and cmc for one query.
|
||||
|
||||
Args:
|
||||
index (np.ndarray): the sorted retrieval index for one query.
|
||||
ql (int): the person id of the query.
|
||||
qc (int): the camera id of the query.
|
||||
gl (np.ndarray): the person ids of the gallery set.
|
||||
gc (np.ndarray): the camera ids of the gallery set.
|
||||
|
||||
Returns:
|
||||
tuple (float, torch.tensor): ap and cmc for one query.
|
||||
"""
|
||||
query_index = (ql == gl)
|
||||
query_index = np.argwhere(query_index)
|
||||
|
||||
camera_index = (qc == gc)
|
||||
camera_index = np.argwhere(camera_index)
|
||||
|
||||
# good index
|
||||
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
|
||||
|
||||
# junk index
|
||||
junk_index1 = np.argwhere(gl == -1)
|
||||
junk_index2 = np.intersect1d(query_index, camera_index)
|
||||
junk_index = np.append(junk_index2, junk_index1)
|
||||
|
||||
AP_tmp, CMC_tmp = self.compute_ap_cmc(index, good_index, junk_index)
|
||||
|
||||
return AP_tmp, CMC_tmp
|
||||
|
||||
def __call__(self, query_result: List, gallery_info: List) -> (float, Dict):
|
||||
"""
|
||||
Calculate the mAP and recall for the indexing results.
|
||||
|
||||
Args:
|
||||
query_result (list): a list of indexing results.
|
||||
gallery_info (list): a list of gallery set information.
|
||||
|
||||
Returns:
|
||||
tuple (float, dict): mean average precision and recall for each position.
|
||||
"""
|
||||
AP = 0.0
|
||||
CMC = torch.IntTensor(range(len(gallery_info))).zero_()
|
||||
gallery_label = np.array([int(gallery_info[idx]["label"]) for idx in range(len(gallery_info))])
|
||||
gallery_cam = np.array([int(gallery_info[idx]["cam"]) for idx in range(len(gallery_info))])
|
||||
|
||||
recall_at_k = dict()
|
||||
for k in self._hyper_params["recall_k"]:
|
||||
recall_at_k[k] = 0
|
||||
|
||||
for i in range(len(query_result)):
|
||||
AP_tmp, CMC_tmp = self.evaluate_once(np.array(query_result[i]["ranked_neighbors_idx"]),
|
||||
int(query_result[i]["label"]), int(query_result[i]["cam"]),
|
||||
gallery_label, gallery_cam)
|
||||
if CMC_tmp[0] == -1:
|
||||
continue
|
||||
CMC = CMC + CMC_tmp
|
||||
AP += AP_tmp
|
||||
|
||||
CMC = CMC.float()
|
||||
CMC = CMC / len(query_result) # average CMC
|
||||
|
||||
for k in recall_at_k:
|
||||
recall_at_k[k] = (CMC[k-1] * 100).item()
|
||||
|
||||
mAP = AP / len(query_result) * 100
|
||||
|
||||
return mAP, recall_at_k
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .helper import EvaluateHelper
|
||||
|
||||
|
||||
__all__ = [
|
||||
'EvaluateHelper',
|
||||
]
|
|
@ -0,0 +1,49 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from ..evaluator import EvaluatorBase
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
class EvaluateHelper:
|
||||
"""
|
||||
A helper class to evaluate query results.
|
||||
"""
|
||||
def __init__(self, evaluator: EvaluatorBase):
|
||||
"""
|
||||
Args:
|
||||
evaluator (EvaluatorBase): an evaluator class.
|
||||
"""
|
||||
self.evaluator = evaluator
|
||||
self.recall_k = evaluator.default_hyper_params["recall_k"]
|
||||
|
||||
def show_results(self, mAP: float, recall_at_k: Dict) -> None:
|
||||
"""
|
||||
Show the evaluate results.
|
||||
|
||||
Args:
|
||||
mAP (float): mean average precision.
|
||||
recall_at_k (Dict): recall at the k position.
|
||||
"""
|
||||
repr_str = "mAP: {:.1f}\n".format(mAP)
|
||||
|
||||
for k in self.recall_k:
|
||||
repr_str += "R@{}: {:.1f}\t".format(k, recall_at_k[k])
|
||||
|
||||
print('--------------- Retrieval Evaluation ------------')
|
||||
print(repr_str)
|
||||
|
||||
def do_eval(self, query_result_info: List, gallery_info: List) -> (float, Dict):
|
||||
"""
|
||||
Get the evaluate results.
|
||||
|
||||
Args:
|
||||
query_result_info (list): a list of indexing results.
|
||||
gallery_info (list): a list of gallery set information.
|
||||
|
||||
Returns:
|
||||
tuple (float, Dict): mean average precision and recall for each position.
|
||||
"""
|
||||
mAP, recall_at_k = self.evaluator(query_result_info, gallery_info)
|
||||
|
||||
return mAP, recall_at_k
|
|
@ -0,0 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..utils import Registry
|
||||
|
||||
EVALUATORS = Registry()
|
Binary file not shown.
|
@ -0,0 +1,15 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .config import get_extractor_cfg, get_aggregators_cfg, get_extract_cfg
|
||||
from .builder import build_aggregators, build_extractor, build_extract_helper
|
||||
|
||||
from .utils import split_dataset, make_data_json
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_extract_cfg',
|
||||
'build_aggregators', 'build_extractor', 'build_extract_helper',
|
||||
'split_dataset', 'make_data_json',
|
||||
]
|
Binary file not shown.
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .aggregators_impl.crow import Crow
|
||||
from .aggregators_impl.gap import GAP
|
||||
from .aggregators_impl.gem import GeM
|
||||
from .aggregators_impl.gmp import GMP
|
||||
from .aggregators_impl.pwa import PWA
|
||||
from .aggregators_impl.r_mac import RMAC
|
||||
from .aggregators_impl.scda import SCDA
|
||||
from .aggregators_impl.spoc import SPoC
|
||||
|
||||
from .aggregators_base import AggregatorBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AggregatorBase',
|
||||
'Crow', 'GAP', 'GeM', 'GMP', 'PWA', 'RMAC', 'SCDA', 'SPoC',
|
||||
'build_aggregators',
|
||||
]
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class AggregatorBase(ModuleBase):
|
||||
r"""
|
||||
The base class for feature aggregators.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(AggregatorBase, self).__init__(hps)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
pass
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class Crow(AggregatorBase):
|
||||
"""
|
||||
Cross-dimensional Weighting for Aggregated Deep Convolutional Features.
|
||||
c.f. https://arxiv.org/pdf/1512.04065.pdf
|
||||
|
||||
Hyper-Params
|
||||
spatial_a (float): hyper-parameter for calculating spatial weight.
|
||||
spatial_b (float): hyper-parameter for calculating spatial weight.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"spatial_a": 2.0,
|
||||
"spatial_b": 2.0,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
self.first_show = True
|
||||
super(Crow, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
spatial_a = self._hyper_params["spatial_a"]
|
||||
spatial_b = self._hyper_params["spatial_b"]
|
||||
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
spatial_weight = fea.sum(dim=1, keepdims=True)
|
||||
z = (spatial_weight ** spatial_a).sum(dim=(2, 3), keepdims=True)
|
||||
z = z ** (1.0 / spatial_a)
|
||||
spatial_weight = (spatial_weight / z) ** (1.0 / spatial_b)
|
||||
|
||||
c, h, w = fea.shape[1:]
|
||||
nonzeros = (fea!=0).float().sum(dim=(2, 3)) / 1.0 / (w * h) + 1e-6
|
||||
channel_weight = torch.log(nonzeros.sum(dim=1, keepdims=True) / nonzeros)
|
||||
|
||||
fea = fea * spatial_weight
|
||||
fea = fea.sum(dim=(2, 3))
|
||||
fea = fea * channel_weight
|
||||
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[Crow Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
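As a rough illustration of the spatial weighting above, here is the same computation written out on a random feature map with the default a = b = 2 (a sketch only; the channel weighting is omitted):

```python
import torch

fea = torch.rand(1, 512, 7, 7)                        # (N, C, H, W) conv feature map
s = fea.sum(dim=1, keepdim=True)                      # sum over channels
z = ((s ** 2).sum(dim=(2, 3), keepdim=True)) ** 0.5   # power normalization term
spatial_weight = (s / z) ** 0.5                       # one weight per spatial location
print(spatial_weight.shape)                           # torch.Size([1, 1, 7, 7])
```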
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class GAP(AggregatorBase):
|
||||
"""
|
||||
Global average pooling.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
self.first_show = True
|
||||
super(GAP, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
fea = fea.mean(dim=3).mean(dim=2)
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[GAP Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
@ -0,0 +1,51 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class GeM(AggregatorBase):
|
||||
"""
|
||||
Generalized-mean pooling.
|
||||
c.f. https://pdfs.semanticscholar.org/a2ca/e0ed91d8a3298b3209fc7ea0a4248b914386.pdf
|
||||
|
||||
Hyper-Params
|
||||
p (float): hyper-parameter for calculating generalized mean. If p = 1, GeM is equal to global average pooling, and
|
||||
if p = +infinity, GeM is equal to global max pooling.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"p": 3.0,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
self.first_show = True
|
||||
super(GeM, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
p = self._hyper_params["p"]
|
||||
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
fea = fea ** p
|
||||
h, w = fea.shape[2:]
|
||||
fea = fea.sum(dim=(2, 3)) * 1.0 / w / h
|
||||
fea = fea ** (1.0 / p)
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[GeM Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
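The generalized-mean pooling above reduces to a one-liner; a minimal sketch with the default p = 3 (illustration only):

```python
import torch

fea, p = torch.rand(2, 512, 7, 7), 3.0
gem = (fea ** p).mean(dim=(2, 3)) ** (1.0 / p)   # shape (2, 512)
# p = 1 recovers global average pooling; p -> +infinity approaches global max pooling.
```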
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class GMP(AggregatorBase):
|
||||
"""
|
||||
Global maximum pooling
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
self.first_show = True
|
||||
super(GMP, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
fea = (fea.max(dim=3)[0]).max(dim=2)[0]
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[GMP Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
@ -0,0 +1,81 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
from ....index.utils import feature_loader
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class PWA(AggregatorBase):
|
||||
"""
|
||||
Part-based Weighting Aggregation.
|
||||
c.f. https://arxiv.org/abs/1705.01247
|
||||
|
||||
Hyper-Params
|
||||
train_fea_dir (str): path of feature dir for selecting channels.
|
||||
n_proposal (int): number of proposals to be selected.
|
||||
alpha (float): alpha for calculating the spatial weight.
|
||||
beta (float): beta for calculating the spatial weight.
|
||||
"""
|
||||
|
||||
default_hyper_params = {
|
||||
"train_fea_dir": "",
|
||||
"n_proposal": 25,
|
||||
"alpha": 2.0,
|
||||
"beta": 2.0,
|
||||
"train_fea_names": ["pool5_GAP"],
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PWA, self).__init__(hps)
|
||||
self.first_show = True
|
||||
assert self._hyper_params["train_fea_dir"] != ""
|
||||
self.selected_proposals_idx = None
|
||||
self.train()
|
||||
|
||||
def train(self) -> None:
|
||||
n_proposal = self._hyper_params["n_proposal"]
|
||||
stacked_fea, _, pos_info = feature_loader.load(
|
||||
self._hyper_params["train_fea_dir"],
|
||||
self._hyper_params["train_fea_names"]
|
||||
)
|
||||
self.selected_proposals_idx = dict()
|
||||
for fea_name in pos_info:
|
||||
st_idx, ed_idx = pos_info[fea_name]
|
||||
fea = stacked_fea[:, st_idx: ed_idx]
|
||||
assert fea.ndim == 2, "invalid train feature"
|
||||
channel_variance = np.std(fea, axis=0)
|
||||
selected_idx = channel_variance.argsort()[-n_proposal:]
|
||||
fea_name = "_".join(fea_name.split("_")[:-1])
|
||||
self.selected_proposals_idx[fea_name] = selected_idx.tolist()
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
alpha, beta = self._hyper_params["alpha"], self._hyper_params["beta"]
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
assert (key in self.selected_proposals_idx), '{} is not in the {}'.format(key, self.selected_proposals_idx.keys())
|
||||
proposals_idx = np.array(self.selected_proposals_idx[key])
|
||||
proposals = fea[:, proposals_idx, :, :]
|
||||
power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdims=True) ** (1.0 / alpha)
|
||||
normed_proposals = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)
|
||||
fea = (fea[:, None, :, :, :] * normed_proposals[:, :, None, :, :]).sum(dim=(3, 4))
|
||||
fea = fea.view(fea.shape[0], -1)
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[PWA Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
@ -0,0 +1,118 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@AGGREGATORS.register
|
||||
class RMAC(AggregatorBase):
|
||||
"""
|
||||
Regional Maximum activation of convolutions (R-MAC).
|
||||
c.f. https://arxiv.org/pdf/1511.05879.pdf
|
||||
|
||||
Hyper-Params
|
||||
level_n (int): number of levels for selecting regions.
|
||||
"""
|
||||
|
||||
default_hyper_params = {
|
||||
"level_n": 3,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(RMAC, self).__init__(hps)
|
||||
self.first_show = True
|
||||
self.cached_regions = dict()
|
||||
|
||||
def _get_regions(self, h: int, w: int) -> List:
|
||||
"""
|
||||
Divide the image into several regions.
|
||||
|
||||
Args:
|
||||
h (int): height for dividing regions.
|
||||
w (int): width for dividing regions.
|
||||
|
||||
Returns:
|
||||
regions (List): a list of region positions.
|
||||
"""
|
||||
if (h, w) in self.cached_regions:
|
||||
return self.cached_regions[(h, w)]
|
||||
|
||||
m = 1
|
||||
n_h, n_w = 1, 1
|
||||
regions = list()
|
||||
if h != w:
|
||||
min_edge = min(h, w)
|
||||
left_space = max(h, w) - min(h, w)
|
||||
iou_target = 0.4
|
||||
iou_best = 1.0
|
||||
while True:
|
||||
iou_tmp = (min_edge ** 2 - min_edge * (left_space // m)) / (min_edge ** 2)
|
||||
|
||||
# small m maybe result in non-overlap
|
||||
if iou_tmp <= 0:
|
||||
m += 1
|
||||
continue
|
||||
|
||||
if abs(iou_tmp - iou_target) <= iou_best:
|
||||
iou_best = abs(iou_tmp - iou_target)
|
||||
m += 1
|
||||
else:
|
||||
break
|
||||
if h < w:
|
||||
n_w = m
|
||||
else:
|
||||
n_h = m
|
||||
|
||||
for i in range(self._hyper_params["level_n"]):
|
||||
region_width = int(2 * 1.0 / (i + 2) * min(h, w))
|
||||
step_size_h = (h - region_width) // n_h
|
||||
step_size_w = (w - region_width) // n_w
|
||||
|
||||
for x in range(n_h):
|
||||
for y in range(n_w):
|
||||
st_x = step_size_h * x
|
||||
ed_x = st_x + region_width - 1
|
||||
assert ed_x < h
|
||||
st_y = step_size_w * y
|
||||
ed_y = st_y + region_width - 1
|
||||
assert ed_y < w
|
||||
regions.append((st_x, st_y, ed_x, ed_y))
|
||||
|
||||
n_h += 1
|
||||
n_w += 1
|
||||
|
||||
self.cached_regions[(h, w)] = regions
|
||||
return regions
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
h, w = fea.shape[2:]
|
||||
final_fea = None
|
||||
regions = self._get_regions(h, w)
|
||||
for _, r in enumerate(regions):
|
||||
st_x, st_y, ed_x, ed_y = r
|
||||
region_fea = (fea[:, :, st_x: ed_x, st_y: ed_y].max(dim=3)[0]).max(dim=2)[0]
|
||||
region_fea = region_fea / torch.norm(region_fea, dim=1, keepdim=True)
|
||||
if final_fea is None:
|
||||
final_fea = region_fea
|
||||
else:
|
||||
final_fea = final_fea + region_fea
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = final_fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[RMAC Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
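For orientation, a sketch of calling the aggregator directly on a dict of feature maps. The import path and the "pool5" key name are assumptions used for illustration:

```python
import torch
from pyretri.extract.aggregator import RMAC   # assumed module path

rmac = RMAC()                                   # default level_n = 3
out = rmac({"pool5": torch.rand(1, 512, 7, 7)})
print(list(out.keys()))                         # ['pool5_RMAC']
```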
@ -0,0 +1,103 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import queue
|
||||
|
||||
import torch
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@AGGREGATORS.register
|
||||
class SCDA(AggregatorBase):
|
||||
"""
|
||||
Selective Convolutional Descriptor Aggregation for Fine-Grained Image Retrieval.
|
||||
c.f. http://www.weixiushen.com/publication/tip17SCDA.pdf
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(SCDA, self).__init__(hps)
|
||||
self.first_show = True
|
||||
|
||||
def bfs(self, x: int, y: int, mask: torch.tensor, cc_map: torch.tensor, cc_id: int) -> int:
|
||||
dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]]
|
||||
q = queue.LifoQueue()
|
||||
q.put((x, y))
|
||||
|
||||
ret = 1
|
||||
cc_map[x][y] = cc_id
|
||||
|
||||
while not q.empty():
|
||||
x, y = q.get()
|
||||
|
||||
for (dx, dy) in dirs:
|
||||
new_x = x + dx
|
||||
new_y = y + dy
|
||||
if 0 <= new_x < mask.shape[0] and 0 <= new_y < mask.shape[1]:
|
||||
if mask[new_x][new_y] == 1 and cc_map[new_x][new_y] == 0:
|
||||
q.put((new_x, new_y))
|
||||
ret += 1
|
||||
cc_map[new_x][new_y] = cc_id
|
||||
return ret
|
||||
|
||||
def find_max_cc(self, mask: torch.tensor) -> torch.tensor:
|
||||
"""
|
||||
Find the largest connected component of the mask.
|
||||
|
||||
Args:
|
||||
mask (torch.tensor): the original mask.
|
||||
|
||||
Returns:
|
||||
mask (torch.tensor): the mask only containing the maximum connected component.
|
||||
"""
|
||||
assert mask.ndim == 4
|
||||
assert mask.shape[1] == 1
|
||||
mask = mask[:, 0, :, :]
|
||||
for i in range(mask.shape[0]):
|
||||
m = mask[i]
|
||||
cc_map = torch.zeros(m.shape)
|
||||
cc_num = list()
|
||||
|
||||
for x in range(m.shape[0]):
|
||||
for y in range(m.shape[1]):
|
||||
if m[x][y] == 1 and cc_map[x][y] == 0:
|
||||
cc_id = len(cc_num) + 1
|
||||
cc_num.append(self.bfs(x, y, m, cc_map, cc_id))
|
||||
|
||||
max_cc_id = cc_num.index(max(cc_num)) + 1
|
||||
m[cc_map != max_cc_id] = 0
|
||||
mask = mask[:, None, :, :]
|
||||
return mask
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
mask = fea.sum(dim=1, keepdims=True)
|
||||
thres = mask.mean(dim=(2, 3), keepdims=True)
|
||||
mask[mask <= thres] = 0
|
||||
mask[mask > thres] = 1
|
||||
mask = self.find_max_cc(mask)
|
||||
fea = fea * mask
|
||||
|
||||
gap = fea.mean(dim=(2, 3))
|
||||
gmp, _ = fea.max(dim=3)
|
||||
gmp, _ = gmp.max(dim=2)
|
||||
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = torch.cat([gap, gmp], dim=1)
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[SCDA Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
@ -0,0 +1,54 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..aggregators_base import AggregatorBase
|
||||
from ...registry import AGGREGATORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@AGGREGATORS.register
|
||||
class SPoC(AggregatorBase):
|
||||
"""
|
||||
SPoC with center prior.
|
||||
c.f. https://arxiv.org/pdf/1510.07493.pdf
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(SPoC, self).__init__(hps)
|
||||
self.first_show = True
|
||||
self.spatial_weight_cache = dict()
|
||||
|
||||
def __call__(self, features: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
if fea.ndimension() == 4:
|
||||
h, w = fea.shape[2:]
|
||||
if (h, w) in self.spatial_weight_cache:
|
||||
spatial_weight = self.spatial_weight_cache[(h, w)]
|
||||
else:
|
||||
sigma = min(h, w) / 2.0 / 3.0
|
||||
x = torch.Tensor(range(w))
|
||||
y = torch.Tensor(range(h))[:, None]
|
||||
spatial_weight = torch.exp(-((x - (w - 1) / 2.0) ** 2 + (y - (h - 1) / 2.0) ** 2) / 2.0 / (sigma ** 2))
|
||||
if torch.cuda.is_available():
|
||||
spatial_weight = spatial_weight.cuda()
|
||||
spatial_weight = spatial_weight[None, None, :, :]
|
||||
self.spatial_weight_cache[(h, w)] = spatial_weight
|
||||
fea = (fea * spatial_weight).sum(dim=(2, 3))
|
||||
ret[key + "_{}".format(self.__class__.__name__)] = fea
|
||||
else:
|
||||
# In case of fc feature.
|
||||
assert fea.ndimension() == 2
|
||||
if self.first_show:
|
||||
print("[SPoC Aggregator]: find 2-dimension feature map, skip aggregation")
|
||||
self.first_show = False
|
||||
ret[key] = fea
|
||||
return ret
|
|
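The center prior used above is simply a Gaussian centred on the feature map; written out for a 7x7 map (sketch only):

```python
import torch

h = w = 7
sigma = min(h, w) / 2.0 / 3.0
x = torch.arange(w, dtype=torch.float32)
y = torch.arange(h, dtype=torch.float32)[:, None]
weight = torch.exp(-((x - (w - 1) / 2.0) ** 2 + (y - (h - 1) / 2.0) ** 2) / (2.0 * sigma ** 2))
print(weight.shape)   # torch.Size([7, 7]); the weight peaks at the centre of the map
```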
@ -0,0 +1,83 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import AGGREGATORS, SPLITTERS, EXTRACTORS
|
||||
from .extractor import ExtractorBase
|
||||
from .splitter import SplitterBase
|
||||
from .aggregator import AggregatorBase
|
||||
from .helper import ExtractHelper
|
||||
|
||||
from ..utils import simple_build
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def build_aggregators(cfg: CfgNode) -> List[AggregatorBase]:
|
||||
"""
|
||||
Instantiate a list of aggregator classes.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
aggregators (list): a list of instances of aggregator class.
|
||||
"""
|
||||
names = cfg["names"]
|
||||
aggregators = list()
|
||||
for name in names:
|
||||
aggregators.append(simple_build(name, cfg, AGGREGATORS))
|
||||
return aggregators
|
||||
|
||||
|
||||
def build_extractor(model: nn.Module, cfg: CfgNode) -> ExtractorBase:
|
||||
"""
|
||||
Instantiate an extractor class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
extractor (ExtractorBase): an instance of extractor class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
extractor = simple_build(name, cfg, EXTRACTORS, model=model)
|
||||
return extractor
|
||||
|
||||
|
||||
def build_splitter(cfg: CfgNode) -> SplitterBase:
|
||||
"""
|
||||
Instantiate a splitter class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
splitter (SplitterBase): an instance of splitter class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
splitter = simple_build(name, cfg, SPLITTERS)
|
||||
return splitter
|
||||
|
||||
|
||||
def build_extract_helper(model: nn.Module, cfg: CfgNode) -> ExtractHelper:
|
||||
"""
|
||||
Instantiate an extract helper class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
helper (ExtractHelper): an instance of extract helper class.
|
||||
"""
|
||||
assemble = cfg.assemble
|
||||
extractor = build_extractor(model, cfg.extractor)
|
||||
splitter = build_splitter(cfg.splitter)
|
||||
aggregators = build_aggregators(cfg.aggregators)
|
||||
helper = ExtractHelper(assemble, extractor, splitter, aggregators)
|
||||
return helper
|
||||
|
|
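A usage sketch for the extract builders. The `pyretri` module path and the layout of per-module hyper-parameters inside the config are assumptions based on this repository's structure:

```python
import torchvision.models as models
from pyretri.extract import get_extract_cfg, build_extract_helper   # assumed import path

cfg = get_extract_cfg()
cfg.extractor.name = "ResSeries"
cfg.extractor.ResSeries.extract_features = ["pool5"]   # assumed per-extractor sub-node
cfg.splitter.name = "Identity"
cfg.aggregators.names = ["GAP"]

model = models.resnet50(pretrained=True)
helper = build_extract_helper(model, cfg)               # ready for helper.do_extract(...)
```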
@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import EXTRACTORS, SPLITTERS, AGGREGATORS
|
||||
|
||||
from ..utils import get_config_from_registry
|
||||
|
||||
|
||||
def get_aggregators_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(AGGREGATORS)
|
||||
cfg["names"] = list()
|
||||
return cfg
|
||||
|
||||
|
||||
def get_splitter_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(SPLITTERS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_extractor_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(EXTRACTORS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_extract_cfg() -> CfgNode:
|
||||
cfg = CfgNode()
|
||||
cfg["assemble"] = 0
|
||||
cfg["extractor"] = get_extractor_cfg()
|
||||
cfg["splitter"] = get_splitter_cfg()
|
||||
cfg["aggregators"] = get_aggregators_cfg()
|
||||
return cfg
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .extractors_impl.vgg_series import VggSeries
|
||||
from .extractors_impl.res_series import ResSeries
|
||||
from .extractors_impl.reid_series import ReIDSeries
|
||||
from .extractors_base import ExtractorBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ExtractorBase',
|
||||
'VggSeries', 'ResSeries',
|
||||
'ReIDSeries',
|
||||
]
|
|
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class ExtractorBase(ModuleBase):
|
||||
"""
|
||||
The base class for feature map extractors.
|
||||
|
||||
Hyper-Parameters
|
||||
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
|
||||
If it is ["all"], then all available features will be output.
|
||||
"""
|
||||
available_feas = list()
|
||||
default_hyper_params = {
|
||||
"extract_features": list(),
|
||||
}
|
||||
|
||||
def __init__(self, model: nn.Module, feature_modules: Dict[str, nn.Module], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
feature_modules (dict): the output layer of the model.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(ExtractorBase, self).__init__(hps)
|
||||
assert len(self._hyper_params["extract_features"]) > 0
|
||||
|
||||
self.model = model.eval()
|
||||
if torch.cuda.is_available():
|
||||
self.model.cuda()
|
||||
if torch.cuda.device_count() > 1:
|
||||
self.model = nn.DataParallel(self.model)
|
||||
self.feature_modules = feature_modules
|
||||
self.feature_buffer = dict()
|
||||
|
||||
if self._hyper_params["extract_features"][0] == "all":
|
||||
self._hyper_params["extract_features"] = self.available_feas
|
||||
for fea in self._hyper_params["extract_features"]:
|
||||
self.feature_buffer[fea] = dict()
|
||||
|
||||
self._register_hook()
|
||||
|
||||
def _register_hook(self) -> None:
|
||||
"""
|
||||
Register hooks to output inner feature map.
|
||||
"""
|
||||
def hook(feature_buffer, fea_name, module, input, output):
|
||||
feature_buffer[fea_name][str(output.device)] = output.data
|
||||
|
||||
for fea in self._hyper_params["extract_features"]:
|
||||
assert fea in self.feature_modules, 'unknown feature {}!'.format(fea)
|
||||
self.feature_modules[fea].register_forward_hook(partial(hook, self.feature_buffer, fea))
|
||||
|
||||
def __call__(self, x: torch.tensor) -> Dict:
|
||||
with torch.no_grad():
|
||||
self.model(x)
|
||||
ret = dict()
|
||||
for fea in self._hyper_params["extract_features"]:
|
||||
ret[fea] = list()
|
||||
devices = list(self.feature_buffer[fea].keys())
|
||||
devices = np.sort(devices)
|
||||
for d in devices:
|
||||
ret[fea].append(self.feature_buffer[fea][d])
|
||||
ret[fea] = torch.cat(ret[fea], dim=0)
|
||||
return ret
|
|
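The extractor works by registering forward hooks on the chosen layers and reading their outputs back out of a buffer after the forward pass. A stripped-down illustration of that mechanism (not taken from the toolbox):

```python
import torch
import torch.nn as nn

buffer = {}
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model[1].register_forward_hook(lambda module, inp, out: buffer.update(relu=out.data))

with torch.no_grad():
    model(torch.rand(1, 3, 8, 8))
print(buffer["relu"].shape)   # torch.Size([1, 8, 6, 6])
```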
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..extractors_base import ExtractorBase
|
||||
from ...registry import EXTRACTORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@EXTRACTORS.register
|
||||
class ReIDSeries(ExtractorBase):
|
||||
"""
|
||||
The extractors for reid baseline models.
|
||||
|
||||
Hyper-Parameters
|
||||
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
|
||||
If it is ["all"], then all available features will be output.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"extract_features": list(),
|
||||
}
|
||||
|
||||
available_feas = ["output"]
|
||||
|
||||
def __init__(self, model: nn.Module, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
children = list(model.children())
|
||||
feature_modules = {
|
||||
"output": children[1].add_block,
|
||||
}
|
||||
super(ReIDSeries, self).__init__(model, feature_modules, hps)
|
|
@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..extractors_base import ExtractorBase
|
||||
from ...registry import EXTRACTORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@EXTRACTORS.register
|
||||
class ResSeries(ExtractorBase):
|
||||
"""
|
||||
The extractors for ResNet.
|
||||
|
||||
Hyper-Parameters
|
||||
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
|
||||
If it is ["all"], then all available features will be output.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"extract_features": list(),
|
||||
}
|
||||
|
||||
available_feas = ["pool5", "pool4", "pool3"]
|
||||
|
||||
def __init__(self, model, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
children = list(model.children())
|
||||
feature_modules = {
|
||||
"pool5": children[-3][-1].relu,
|
||||
"pool4": children[-4][-1].relu,
|
||||
"pool3": children[-5][-1].relu
|
||||
}
|
||||
super(ResSeries, self).__init__(model, feature_modules, hps)
|
|
@ -0,0 +1,34 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..extractors_base import ExtractorBase
|
||||
from ...registry import EXTRACTORS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@EXTRACTORS.register
|
||||
class VggSeries(ExtractorBase):
|
||||
"""
|
||||
The extractors for VGG Net.
|
||||
|
||||
Hyper-Parameters
|
||||
extract_features (list): indicates which feature maps to output. See available_feas for available feature maps.
|
||||
If it is ["all"], then all available features will be output.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"extract_features": list(),
|
||||
}
|
||||
|
||||
available_feas = ["fc", "pool5", "pool4"]
|
||||
|
||||
def __init__(self, model, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): the model for extracting features.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
feature_modules = {
|
||||
"fc": model.classifier[4],
|
||||
"pool5": model.features[-1],
|
||||
"pool4": model.features[23]
|
||||
}
|
||||
super(VggSeries, self).__init__(model, feature_modules, hps)
|
|
@ -0,0 +1,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .helper import ExtractHelper
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ExtractHelper',
|
||||
]
|
|
@ -0,0 +1,141 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
||||
from ..extractor import ExtractorBase
|
||||
from ..aggregator import AggregatorBase
|
||||
from ..splitter import SplitterBase
|
||||
from ...utils import ensure_dir
|
||||
from torch.utils.data import DataLoader
|
||||
from typing import Dict, List
|
||||
|
||||
class ExtractHelper:
|
||||
"""
|
||||
A helper class to extract feature maps from model, and then aggregate them.
|
||||
"""
|
||||
def __init__(self, assemble: int, extractor: ExtractorBase, splitter: SplitterBase, aggregators: List[AggregatorBase]):
|
||||
"""
|
||||
Args:
|
||||
assemble (int): way to assemble features if transformers produce multiple images (e.g. TwoFlip, TenCrop).
|
||||
extractor (ExtractorBase): an extractor class for extracting features.
|
||||
splitter (SplitterBase): a splitter class for splitting features.
|
||||
aggregators (list): a list of aggregator classes for aggregating features.
|
||||
"""
|
||||
self.assemble = assemble
|
||||
self.extractor = extractor
|
||||
self.splitter = splitter
|
||||
self.aggregators = aggregators
|
||||
|
||||
def _save_part_fea(self, datainfo: Dict, save_fea: List, save_path: str) -> None:
|
||||
"""
|
||||
Save features in a json file.
|
||||
|
||||
Args:
|
||||
datainfo (dict): the dataset information contained the data json file.
|
||||
save_fea (list): a list of features to be saved.
|
||||
save_path (str): the save path for the extracted features.
|
||||
"""
|
||||
save_json = dict()
|
||||
for key in datainfo:
|
||||
if key != "info_dicts":
|
||||
save_json[key] = datainfo[key]
|
||||
save_json["info_dicts"] = save_fea
|
||||
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump(save_json, f)
|
||||
|
||||
def extract_one_batch(self, batch: Dict) -> Dict:
|
||||
"""
|
||||
Extract features for a batch of images.
|
||||
|
||||
Args:
|
||||
batch (dict): a dict containing several image tensors.
|
||||
|
||||
Returns:
|
||||
all_fea_dict (dict): a dict containing extracted features.
|
||||
"""
|
||||
img = batch["img"]
|
||||
if torch.cuda.is_available():
|
||||
img = img.cuda()
|
||||
# img is in the shape (N, IMG_AUG, C, H, W)
|
||||
batch_size, aug_size = img.shape[0], img.shape[1]
|
||||
img = img.view(-1, img.shape[2], img.shape[3], img.shape[4])
|
||||
|
||||
features = self.extractor(img)
|
||||
|
||||
features = self.splitter(features)
|
||||
|
||||
all_fea_dict = dict()
|
||||
for aggregator in self.aggregators:
|
||||
fea_dict = aggregator(features)
|
||||
all_fea_dict.update(fea_dict)
|
||||
|
||||
# PyTorch will duplicate inputs if batch_size < n_gpu
|
||||
for key in all_fea_dict.keys():
|
||||
if self.assemble == 0:
|
||||
features = all_fea_dict[key][:img.shape[0], :]
|
||||
features = features.view(batch_size, aug_size, -1)
|
||||
features = features.view(batch_size, -1)
|
||||
all_fea_dict[key] = features
|
||||
elif self.assemble == 1:
|
||||
features = all_fea_dict[key].view(batch_size, aug_size, -1)
|
||||
features = features.sum(dim=1)
|
||||
all_fea_dict[key] = features
|
||||
|
||||
return all_fea_dict
|
||||
|
||||
def do_extract(self, dataloader: DataLoader, save_path: str, save_interval: int = 5000) -> None:
|
||||
"""
|
||||
Extract features for a whole dataset and save features in json files.
|
||||
|
||||
Args:
|
||||
dataloader (DataLoader): a DataLoader class for loading images for training.
|
||||
save_path (str): the save path for the extracted features.
|
||||
save_interval (int, optional): number of features saved in one part file.
|
||||
"""
|
||||
datainfo = dataloader.dataset.data_info
|
||||
pbar = tqdm(range(len(dataloader)))
|
||||
save_fea = list()
|
||||
part_cnt = 0
|
||||
ensure_dir(save_path)
|
||||
|
||||
for _, batch in zip(pbar, dataloader):
|
||||
feature_dict = self.extract_one_batch(batch)
|
||||
for i in range(len(batch["img"])):
|
||||
idx = batch["idx"][i]
|
||||
save_fea.append(deepcopy(datainfo["info_dicts"][idx]))
|
||||
single_fea_dict = dict()
|
||||
for key in feature_dict:
|
||||
single_fea_dict[key] = feature_dict[key][i].tolist()
|
||||
save_fea[-1]["feature"] = single_fea_dict
|
||||
save_fea[-1]["idx"] = int(idx)
|
||||
|
||||
if len(save_fea) >= save_interval:
|
||||
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
|
||||
part_cnt += 1
|
||||
del save_fea
|
||||
save_fea = list()
|
||||
|
||||
if len(save_fea) >= 1:
|
||||
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
|
||||
|
||||
def do_single_extract(self, img: torch.Tensor) -> [Dict]:
|
||||
"""
|
||||
Extract features for a single image.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): a single image tensor.
|
||||
|
||||
Returns:
|
||||
[fea_dict] (sequence): the extract features of the image.
|
||||
"""
|
||||
batch = dict()
|
||||
batch["img"] = img.view(1, img.shape[0], img.shape[1], img.shape[2], img.shape[3])
|
||||
fea_dict = self.extract_one_batch(batch)
|
||||
|
||||
return [fea_dict]
|
|
@ -0,0 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from ..utils.registry import Registry
|
||||
|
||||
EXTRACTORS = Registry()
|
||||
SPLITTERS = Registry()
|
||||
AGGREGATORS = Registry()
|
Binary file not shown.
|
@ -0,0 +1,13 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .splitter_impl.identity import Identity
|
||||
from .splitter_impl.pcb import PCB
|
||||
from .splitter_base import SplitterBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'SplitterBase',
|
||||
'Identity', 'PCB',
|
||||
]
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class SplitterBase(ModuleBase):
|
||||
"""
|
||||
The base class for splitter function.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(SplitterBase, self).__init__(hps)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, features: torch.tensor) -> Dict:
|
||||
pass
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..splitter_base import SplitterBase
|
||||
from ...registry import SPLITTERS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@SPLITTERS.register
|
||||
class Identity(SplitterBase):
|
||||
"""
|
||||
Directly return feature maps without any operations.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(Identity, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: torch.tensor) -> torch.tensor:
|
||||
return features
|
|
@ -0,0 +1,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..splitter_base import SplitterBase
|
||||
from ...registry import SPLITTERS
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@SPLITTERS.register
|
||||
class PCB(SplitterBase):
|
||||
"""
|
||||
PCB function to split feature maps.
|
||||
c.f. http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf
|
||||
|
||||
Hyper-Params:
|
||||
stripe_num (int): the number of horizontal stripes the feature map is divided into.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
'stripe_num': 2,
|
||||
}
|
||||
|
||||
def __init__(self, hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PCB, self).__init__(hps)
|
||||
|
||||
def __call__(self, features: torch.tensor) -> Dict:
|
||||
ret = dict()
|
||||
for key in features:
|
||||
fea = features[key]
|
||||
assert fea.ndimension() == 4
|
||||
assert self.default_hyper_params["stripe_num"] <= fea.shape[2], \
|
||||
'stripe num must be less than or equal to the height of fea'
|
||||
|
||||
stride = fea.shape[2] // self.default_hyper_params["stripe_num"]
|
||||
|
||||
for i in range(int(self.default_hyper_params["stripe_num"])):
|
||||
ret[key + "_part_{}".format(i)] = fea[:, :, stride * i: stride * (i + 1), :]
|
||||
ret[key + "_global"] = fea
|
||||
return ret
|
|
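The split above simply slices the feature map into horizontal stripes of equal height; a small sketch with stripe_num = 2 (illustration only):

```python
import torch

fea = torch.rand(1, 2048, 8, 4)     # (N, C, H, W)
stride = fea.shape[2] // 2
parts = [fea[:, :, i * stride:(i + 1) * stride, :] for i in range(2)]
print([tuple(p.shape) for p in parts])   # [(1, 2048, 4, 4), (1, 2048, 4, 4)], plus the global map
```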
@ -0,0 +1,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .make_data_json import make_data_json
|
||||
from .split_dataset import split_dataset
|
||||
|
||||
__all__ = [
|
||||
'split_dataset', 'make_data_json',
|
||||
]
|
|
@ -0,0 +1,115 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
def make_ds_for_general(dataset_path: str, save_path: str) -> None:
|
||||
"""
|
||||
Generate the data json file for a dataset that collects images with the same label in one directory, e.g. CUB-200-2011.
|
||||
|
||||
Args:
|
||||
dataset_path (str): the path of the dataset.
|
||||
save_path (str): the path for saving the data json file.
|
||||
"""
|
||||
info_dicts = list()
|
||||
img_dirs = os.listdir(dataset_path)
|
||||
label_list = list()
|
||||
label_to_idx = dict()
|
||||
for dir in img_dirs:
|
||||
for root, _, files in os.walk(os.path.join(dataset_path, dir)):
|
||||
for file in files:
|
||||
info_dict = dict()
|
||||
info_dict['path'] = os.path.join(root, file)
|
||||
if dir not in label_list:
|
||||
label_to_idx[dir] = len(label_list)
|
||||
label_list.append(dir)
|
||||
info_dict['label'] = dir
|
||||
info_dict['label_idx'] = label_to_idx[dir]
|
||||
info_dicts += [info_dict]
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump({'nr_class': len(img_dirs), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
|
||||
|
||||
|
||||
def make_ds_for_oxford(dataset_path, save_path: str or None=None, gt_path: str or None=None) -> None:
|
||||
"""
|
||||
Generate the data json file for the Oxford dataset.
|
||||
|
||||
Args:
|
||||
dataset_path (str): the path of the dataset.
|
||||
save_path (str): the path for saving the data json file.
|
||||
gt_path (str, optional): the path of the ground truth, necessary for Oxford.
|
||||
"""
|
||||
label_list = list()
|
||||
info_dicts = list()
|
||||
query_info = dict()
|
||||
if 'query' in dataset_path:
|
||||
for root, _, files in os.walk(gt_path):
|
||||
for file in files:
|
||||
if 'query' in file:
|
||||
with open(os.path.join(root, file), 'r') as f:
|
||||
line = f.readlines()[0].strip('\n').split(' ')
|
||||
query_name = file[:-10]
|
||||
label = line[0][5:]
|
||||
bbox = [float(line[1]), float(line[2]), float(line[3]), float(line[4])]
|
||||
query_info[label] = {'query_name': query_name, 'bbox': bbox,}
|
||||
|
||||
for root, _, files in os.walk(dataset_path):
|
||||
for file in files:
|
||||
info_dict = dict()
|
||||
info_dict['path'] = os.path.join(root, file)
|
||||
label = file.split('.')[0]
|
||||
if label not in label_list:
|
||||
label_list.append(label)
|
||||
info_dict['label'] = label
|
||||
if 'query' in dataset_path:
|
||||
info_dict['bbox'] = query_info[label]['bbox']
|
||||
info_dict['query_name'] = query_info[label]['query_name']
|
||||
info_dicts += [info_dict]
|
||||
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump({'nr_class': len(label_list), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
|
||||
|
||||
|
||||
def make_ds_for_reid(dataset_path: str, save_path: str) -> None:
|
||||
"""
|
||||
Generate the data json file for a Re-ID dataset.
|
||||
|
||||
Args:
|
||||
dataset_path (str): the path of the dataset.
|
||||
save_path (str): the path for saving the data json file.
|
||||
"""
|
||||
label_list = list()
|
||||
info_dicts = list()
|
||||
for root, _, files in os.walk(dataset_path):
|
||||
for file in files:
|
||||
info_dict = dict()
|
||||
info_dict['path'] = os.path.join(root, file)
|
||||
label = file.split('_')[0]
|
||||
cam = file.split('_')[1][1]
|
||||
if label not in label_list:
|
||||
label_list.append(label)
|
||||
info_dict['label'] = label
|
||||
info_dict['cam'] = cam
|
||||
info_dicts += [info_dict]
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump({'nr_class': len(label_list), 'path_type': 'absolute_path', 'info_dicts': info_dicts}, f)
|
||||
|
||||
|
||||
def make_data_json(dataset_path: str, save_path: str, type: str, gt_path: str or None=None) -> None:
|
||||
"""
|
||||
Generate the data json file for a dataset.
|
||||
|
||||
Args:
|
||||
dataset_path (str): the path of the dataset.
|
||||
save_path (str): the path for saving the data json file.
|
||||
type (str): the structure type of the dataset.
|
||||
gt_path (str, optional): the path of the ground truth, necessary for Oxford.
|
||||
"""
|
||||
assert type in ['general', 'oxford', 'reid']
|
||||
if type == 'general':
|
||||
make_ds_for_general(dataset_path, save_path)
|
||||
elif type == 'oxford':
|
||||
make_ds_for_oxford(dataset_path, save_path, gt_path)
|
||||
elif type == 'reid':
|
||||
make_ds_for_reid(dataset_path, save_path)
|
|
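A minimal usage sketch of the dispatcher above, assuming the functions defined in this script are in scope; the dataset and output paths are placeholders, and the `general` type corresponds to the one-folder-per-class layout described in the tutorial.

```python
# Hypothetical usage sketch (paths are placeholders): package the gallery
# and query subsets of a "general"-type dataset into pickle files that the
# later stages can load.
if __name__ == '__main__':
    make_data_json(dataset_path='/data/caltech101/gallery',
                   save_path='data_jsons/caltech_gallery.json',
                   type='general')
    make_data_json(dataset_path='/data/caltech101/query',
                   save_path='data_jsons/caltech_query.json',
                   type='general')
```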
@@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from shutil import copyfile
|
||||
|
||||
|
||||
def split_dataset(dataset_path: str, split_file: str) -> None:
|
||||
"""
|
||||
Split the dataset according to the given splitting rules.
|
||||
|
||||
Args:
|
||||
dataset_path (str): the path of the dataset.
|
||||
split_file (str): the path of the file containing the splitting rules.
|
||||
"""
|
||||
|
||||
with open(split_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
path = line.strip('\n').split(' ')[0]
|
||||
is_gallery = line.strip('\n').split(' ')[1]
|
||||
if is_gallery == '0':
|
||||
src = os.path.join(dataset_path, path)
|
||||
dst = src.replace(path.split('/')[0], 'query')
|
||||
dst_index = len(dst.split('/')[-1])
|
||||
dst_dir = dst[:len(dst) - dst_index]
|
||||
if not os.path.isdir(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
os.symlink(src, dst)
|
||||
elif is_gallery == '1':
|
||||
src = os.path.join(dataset_path, path)
|
||||
dst = src.replace(path.split('/')[0], 'gallery')
|
||||
dst_index = len(dst.split('/')[-1])
|
||||
dst_dir = dst[:len(dst) - dst_index]
|
||||
if not os.path.isdir(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
os.symlink(src, dst)
|
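A hedged usage sketch of `split_dataset`, assuming a split file in which each line holds a relative image path followed by a 0/1 flag (0 for query, 1 for gallery), which is the format the parser above expects; all paths are placeholders.

```python
# Hypothetical usage sketch. The split file is assumed to look like:
#   images/001.ak47/001_0001.jpg 1
#   images/001.ak47/001_0002.jpg 0
# where 1 links the image under gallery/ and 0 under query/, as parsed above.
split_dataset(dataset_path='/data/caltech101',
              split_file='/data/caltech101/split.txt')
```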
Binary file not shown.
|
@@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .config import get_index_cfg
|
||||
|
||||
from .builder import build_index_helper
|
||||
|
||||
from .utils import feature_loader
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_index_cfg',
|
||||
'build_index_helper',
|
||||
'feature_loader',
|
||||
]
|
|
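A hedged sketch of using the exported `feature_loader`, assuming the package is importable as `pyretri` and that features were already extracted to the given directory; the return layout (feature matrix, per-image info, and a name-to-column-range dict) is inferred from how the dimension processors below consume it.

```python
# Hypothetical sketch; the feature directory and feature name are placeholders.
from pyretri.index import feature_loader

fea, info, pos_info = feature_loader.load('/data/features/gallery', ['pool5_GAP'])
print(fea.shape)    # (num_images, feature_dim)
print(pos_info)     # e.g. {'pool5_GAP': (0, 2048)} -- column range of each feature
```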
@@ -0,0 +1,94 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import ENHANCERS, METRICS, DIMPROCESSORS, RERANKERS
|
||||
from .feature_enhancer import EnhanceBase
|
||||
from .helper import IndexHelper
|
||||
from .metric import MetricBase
|
||||
from .dim_processor import DimProcessorBase
|
||||
from .re_ranker import ReRankerBase
|
||||
|
||||
from ..utils import simple_build
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def build_enhance(cfg: CfgNode) -> EnhanceBase:
|
||||
"""
|
||||
Instantiate a feature enhancer class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
enhance (EnhanceBase): an instance of feature enhancer class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
enhance = simple_build(name, cfg, ENHANCERS)
|
||||
return enhance
|
||||
|
||||
|
||||
def build_metric(cfg: CfgNode) -> MetricBase:
|
||||
"""
|
||||
Instantiate a metric class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
metric (MetricBase): an instance of metric class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
metric = simple_build(name, cfg, METRICS)
|
||||
return metric
|
||||
|
||||
|
||||
def build_processors(feature_names: List[str], cfg: CfgNode) -> List[DimProcessorBase]:
|
||||
"""
|
||||
Instantiate a list of dimension processor classes.
|
||||
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
processors (list): a list of instances of dimension processor classes.
|
||||
"""
|
||||
names = cfg["names"]
|
||||
processors = list()
|
||||
for name in names:
|
||||
processors.append(simple_build(name, cfg, DIMPROCESSORS, feature_names=feature_names))
|
||||
return processors
|
||||
|
||||
|
||||
def build_ranker(cfg: CfgNode) -> ReRankerBase:
|
||||
"""
|
||||
Instantiate a re-ranker class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
re_rank (ReRankerBase): an instance of re-ranker class.
|
||||
"""
|
||||
name = cfg["name"]
|
||||
re_rank = simple_build(name, cfg, RERANKERS)
|
||||
return re_rank
|
||||
|
||||
|
||||
def build_index_helper(cfg: CfgNode) -> IndexHelper:
|
||||
"""
|
||||
Instantiate an index helper class.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): the configuration tree.
|
||||
|
||||
Returns:
|
||||
helper (IndexHelper): an instance of index helper class.
|
||||
"""
|
||||
dim_processors = build_processors(cfg["feature_names"], cfg.dim_processors)
|
||||
metric = build_metric(cfg.metric)
|
||||
feature_enhancer = build_enhance(cfg.feature_enhancer)
|
||||
re_ranker = build_ranker(cfg.re_ranker)
|
||||
helper = IndexHelper(dim_processors, feature_enhancer, metric, re_ranker)
|
||||
return helper
|
|
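A hedged sketch of driving the builders above from a yacs config, assuming the package is importable as `pyretri` and that `indexing.yaml` (a placeholder name) fills in the `dim_processors`, `feature_enhancer`, `metric` and `re_ranker` sections consumed by `build_index_helper`.

```python
# Hypothetical sketch: assemble an IndexHelper from a YAML config file.
from pyretri.index import get_index_cfg, build_index_helper

cfg = get_index_cfg()                 # default config tree with "unknown" placeholders
cfg.merge_from_file('indexing.yaml')  # standard yacs CfgNode API; file name is a placeholder
cfg.freeze()

index_helper = build_index_helper(cfg)
```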
@@ -0,0 +1,43 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .registry import ENHANCERS, METRICS, DIMPROCESSORS, RERANKERS
|
||||
|
||||
from ..utils import get_config_from_registry
|
||||
|
||||
|
||||
def get_enhancer_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(ENHANCERS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_metric_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(METRICS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_processors_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(DIMPROCESSORS)
|
||||
cfg["names"] = ["unknown"]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_ranker_cfg() -> CfgNode:
|
||||
cfg = get_config_from_registry(RERANKERS)
|
||||
cfg["name"] = "unknown"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_index_cfg() -> CfgNode:
|
||||
cfg = CfgNode()
|
||||
cfg["query_fea_dir"] = "unknown"
|
||||
cfg["gallery_fea_dir"] = "unknown"
|
||||
cfg["feature_names"] = ["all"]
|
||||
cfg["dim_processors"] = get_processors_cfg()
|
||||
cfg["feature_enhancer"] = get_enhancer_cfg()
|
||||
cfg["metric"] = get_metric_cfg()
|
||||
cfg["re_ranker"] = get_ranker_cfg()
|
||||
return cfg
|
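A hedged sketch of filling the `"unknown"` placeholders programmatically instead of merging a YAML file. The directory paths are placeholders, and the component names must match classes registered in the corresponding registries: `L2Normalize` and `PCA` are registered further below, while the enhancer, metric and re-ranker names are assumptions.

```python
# Hypothetical sketch: override the default index config in code.
cfg = get_index_cfg()
cfg.query_fea_dir = '/data/features/query'         # placeholder path
cfg.gallery_fea_dir = '/data/features/gallery'     # placeholder path
cfg.dim_processors.names = ['L2Normalize', 'PCA']  # registered dim processors (shown below)
cfg.feature_enhancer.name = 'Identity'             # assumed registered enhancer name
cfg.metric.name = 'KNN'                            # assumed registered metric name
cfg.re_ranker.name = 'Identity'                    # assumed registered re-ranker name
```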
Binary file not shown.
|
@@ -0,0 +1,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .dim_processors_impl.identity import Identity
|
||||
from .dim_processors_impl.l2_normalize import L2Normalize
|
||||
from .dim_processors_impl.part_pca import PartPCA
|
||||
from .dim_processors_impl.part_svd import PartSVD
|
||||
from .dim_processors_impl.pca import PCA
|
||||
from .dim_processors_impl.svd import SVD
|
||||
from .dim_processors_impl.rmac_pca import RMACPCA
|
||||
from .dim_processors_base import DimProcessorBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'DimProcessorBase',
|
||||
'Identity', 'L2Normalize', 'PartPCA', 'PartSVD', 'PCA', 'SVD', 'RMACPCA',
|
||||
]
|
|
@@ -0,0 +1,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import ModuleBase
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class DimProcessorBase(ModuleBase):
|
||||
"""
|
||||
The base class of dimension processors.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
ModuleBase.__init__(self, hps)
|
||||
self.feature_names = feature_names
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
pass
|
|
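As a hedged illustration of the contract this base class defines, the following sketch registers a toy processor. The `ZeroCenter` class is hypothetical, and the `pyretri.index.registry` / `pyretri.index.dim_processors` module paths are assumptions inferred from the relative imports used in the implementations below.

```python
# Hypothetical sketch: a toy dimension processor that only centres the
# features, following the same pattern as the registered processors below.
import numpy as np
from typing import Dict, List

from pyretri.index.registry import DIMPROCESSORS            # assumed module path
from pyretri.index.dim_processors import DimProcessorBase   # exported by the __init__ above

@DIMPROCESSORS.register
class ZeroCenter(DimProcessorBase):
    """Subtract the per-dimension mean (illustrative only)."""
    default_hyper_params = dict()

    def __init__(self, feature_names: List[str], hps: Dict or None = None):
        super(ZeroCenter, self).__init__(feature_names, hps)

    def __call__(self, fea: np.ndarray) -> np.ndarray:
        # centre each dimension over the batch of features
        return fea - fea.mean(axis=0, keepdims=True)
```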
@@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile
|
||||
|
||||
|
||||
modules = glob.glob(dirname(__file__) + "/*.py")
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") and not f.endswith("utils.py")
|
||||
]
|
|
@@ -0,0 +1,26 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class Identity(DimProcessorBase):
|
||||
"""
|
||||
Directly return the feature without any dimension reduction operations.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(Identity, self).__init__(feature_names, hps)
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
return fea
|
|
@@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class L2Normalize(DimProcessorBase):
|
||||
"""
|
||||
L2 normalize the features.
|
||||
"""
|
||||
default_hyper_params = dict()
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(L2Normalize, self).__init__(feature_names, hps)
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
return normalize(fea, norm="l2")
|
|
@@ -0,0 +1,92 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from ...utils import feature_loader
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.decomposition import PCA as SKPCA
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class PartPCA(DimProcessorBase):
|
||||
"""
|
||||
Part PCA divides the whole feature into several parts and then applies a PCA transformation to each part.
|
||||
It is usually used for features that are extracted from several feature maps and concatenated together.
|
||||
|
||||
Hyper-Params:
|
||||
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
|
||||
whiten (bool): whether to whiten each part.
|
||||
train_fea_dir (str): the path of features for training PCA.
|
||||
l2 (bool): whether to l2-normalize the training features.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"proj_dim": 0,
|
||||
"whiten": True,
|
||||
"train_fea_dir": "unknown",
|
||||
"l2": True,
|
||||
}
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PartPCA, self).__init__(feature_names, hps)
|
||||
|
||||
self.pcas = dict()
|
||||
self._train(self._hyper_params["train_fea_dir"])
|
||||
|
||||
def _train(self, fea_dir: str) -> None:
|
||||
"""
|
||||
Train the part PCA.
|
||||
|
||||
Args:
|
||||
fea_dir (str): the path of features for training part PCA.
|
||||
"""
|
||||
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
|
||||
fea_names = list(pos_info.keys())
|
||||
ori_dim = fea.shape[1]
|
||||
|
||||
already_proj_dim = 0
|
||||
for fea_name in fea_names:
|
||||
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
|
||||
ori_part_dim = ed_idx - st_idx
|
||||
if self._hyper_params["proj_dim"] == 0:
|
||||
proj_part_dim = ori_part_dim
|
||||
else:
|
||||
ratio = self._hyper_params["proj_dim"] * 1.0 / ori_dim
|
||||
if fea_name != fea_names[-1]:
|
||||
proj_part_dim = int(ori_part_dim * ratio)
|
||||
else:
|
||||
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
|
||||
assert proj_part_dim <= ori_part_dim, "reduction dimension can not be distributed to each part!"
|
||||
already_proj_dim += proj_part_dim
|
||||
|
||||
pca = SKPCA(n_components=proj_part_dim, whiten=self._hyper_params["whiten"])
|
||||
train_fea = fea[:, st_idx: ed_idx]
|
||||
if self._hyper_params["l2"]:
|
||||
train_fea = normalize(train_fea, norm="l2")
|
||||
pca.fit(train_fea)
|
||||
self.pcas[fea_name] = {
|
||||
"pos": (st_idx, ed_idx),
|
||||
"pca": pca
|
||||
}
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
fea_names = np.sort(list(self.pcas.keys()))
|
||||
ret = list()
|
||||
for fea_name in fea_names:
|
||||
st_idx, ed_idx = self.pcas[fea_name]["pos"][0], self.pcas[fea_name]["pos"][1]
|
||||
pca = self.pcas[fea_name]["pca"]
|
||||
|
||||
ori_fea = fea[:, st_idx: ed_idx]
|
||||
proj_fea = pca.transform(ori_fea)
|
||||
|
||||
ret.append(proj_fea)
|
||||
|
||||
ret = np.concatenate(ret, axis=1)
|
||||
return ret
|
|
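A hedged usage sketch of `PartPCA`; the feature names, directories and target dimension are placeholders, and the hyper-parameter keys are the ones declared in `default_hyper_params` above.

```python
# Hypothetical sketch: reduce a two-part concatenated feature to 512 dims.
# train_fea_dir must point to features saved in the format expected by
# feature_loader.load; all values here are placeholders.
import numpy as np

part_pca = PartPCA(
    feature_names=['pool4_GAP', 'pool5_GAP'],        # assumed feature names
    hps={'proj_dim': 512,
         'whiten': True,
         'train_fea_dir': '/data/features/gallery',  # placeholder path
         'l2': True},
)

fea = np.load('query_features.npy')  # placeholder: (N, D) array, D matching the training features
reduced = part_pca(fea)              # -> (N, 512)
```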
@@ -0,0 +1,100 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from ...utils import feature_loader
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.decomposition import TruncatedSVD as SKSVD
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class PartSVD(DimProcessorBase):
|
||||
"""
|
||||
Part SVD divides the whole feature into several parts and then applies an SVD transformation to each part.
|
||||
It is usually used for features that are extracted from several feature maps and concatenated together.
|
||||
|
||||
Hyper-Params:
|
||||
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done
|
||||
(in SVD, the original dimension is reduced by 1).
|
||||
whiten (bool): whether to whiten each part.
|
||||
train_fea_dir (str): the path of features for training SVD.
|
||||
l2 (bool): whether to l2-normalize the training features.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"proj_dim": 0,
|
||||
"whiten": True,
|
||||
"train_fea_dir": "unknown",
|
||||
"l2": True,
|
||||
}
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PartSVD, self).__init__(feature_names, hps)
|
||||
|
||||
self.svds = dict()
|
||||
self._train(self._hyper_params["train_fea_dir"])
|
||||
|
||||
def _train(self, fea_dir: str) -> None:
|
||||
"""
|
||||
Train the part SVD.
|
||||
|
||||
Args:
|
||||
fea_dir (str): the path of features for training part SVD.
|
||||
"""
|
||||
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
|
||||
fea_names = list(pos_info.keys())
|
||||
ori_dim = fea.shape[1]
|
||||
|
||||
already_proj_dim = 0
|
||||
for fea_name in fea_names:
|
||||
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
|
||||
ori_part_dim = ed_idx - st_idx
|
||||
if self._hyper_params["proj_dim"] == 0:
|
||||
proj_part_dim = ori_part_dim - 1
|
||||
else:
|
||||
ratio = self._hyper_params["proj_dim"] * 1.0 / ori_dim
|
||||
if fea_name != fea_names[-1]:
|
||||
proj_part_dim = int(ori_part_dim * ratio)
|
||||
else:
|
||||
proj_part_dim = self._hyper_params["proj_dim"] - already_proj_dim
|
||||
assert proj_part_dim < ori_part_dim, "reduction dimension can not be distributed to each part!"
|
||||
|
||||
svd = SKSVD(n_components=proj_part_dim)
|
||||
train_fea = fea[:, st_idx: ed_idx]
|
||||
if self._hyper_params["l2"]:
|
||||
train_fea = normalize(train_fea, norm="l2")
|
||||
train_fea = svd.fit_transform(train_fea)
|
||||
std = train_fea.std(axis=0, keepdims=True)
|
||||
self.svds[fea_name] = {
|
||||
"pos": (st_idx, ed_idx),
|
||||
"svd": svd,
|
||||
"std": std
|
||||
}
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
if self._hyper_params["proj_dim"] != 0:
|
||||
ret = np.zeros(shape=(fea.shape[0], self._hyper_params["proj_dim"]))
|
||||
else:
|
||||
ret = np.zeros(shape=(fea.shape[0], fea.shape[1] - len(self.svds)))
|
||||
|
||||
out_st = 0  # running write offset into the reduced output
for fea_name in self.svds:
|
||||
st_idx, ed_idx = self.svds[fea_name]["pos"][0], self.svds[fea_name]["pos"][1]
|
||||
svd = self.svds[fea_name]["svd"]
|
||||
|
||||
proj_fea = fea[:, st_idx: ed_idx]
|
||||
proj_fea = svd.transform(proj_fea)
|
||||
if self._hyper_params["whiten"]:
|
||||
proj_fea = proj_fea / (self.svds[fea_name]["std"] + 1e-6)
|
||||
|
||||
ret[:, out_st: out_st + proj_fea.shape[1]] = proj_fea
out_st += proj_fea.shape[1]
|
||||
|
||||
return ret
|
||||
|
||||
|
|
@@ -0,0 +1,59 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from ...utils import feature_loader
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.decomposition import PCA as SKPCA
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class PCA(DimProcessorBase):
|
||||
"""
|
||||
Do the PCA transformation for dimension reduction.
|
||||
|
||||
Hyper-Params:
|
||||
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
|
||||
whiten (bool): whether to whiten.
|
||||
train_fea_dir (str): the path of features for training PCA.
|
||||
l2 (bool): whether to l2-normalize the training features.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"proj_dim": 0,
|
||||
"whiten": True,
|
||||
"train_fea_dir": "unknown",
|
||||
"l2": True,
|
||||
}
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(PCA, self).__init__(feature_names, hps)
|
||||
|
||||
# proj_dim == 0 means "no reduction": let sklearn keep all components
proj_dim = self._hyper_params["proj_dim"]
self.pca = SKPCA(n_components=proj_dim if proj_dim != 0 else None, whiten=self._hyper_params["whiten"])
|
||||
self._train(self._hyper_params["train_fea_dir"])
|
||||
|
||||
def _train(self, fea_dir: str) -> None:
|
||||
"""
|
||||
Train the PCA.
|
||||
|
||||
Args:
|
||||
fea_dir (str): the path of features for training PCA.
|
||||
"""
|
||||
train_fea, _, _ = feature_loader.load(fea_dir, self.feature_names)
|
||||
if self._hyper_params["l2"]:
|
||||
train_fea = normalize(train_fea, norm="l2")
|
||||
self.pca.fit(train_fea)
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
ori_fea = fea
|
||||
proj_fea = self.pca.transform(ori_fea)
|
||||
return proj_fea
|
||||
|
||||
|
|
@@ -0,0 +1,90 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from ...utils import feature_loader
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class RMACPCA(DimProcessorBase):
|
||||
"""
|
||||
Do the PCA transformation for R-MAC only.
|
||||
When this transformation is called, each part feature is l2-normalized, projected by PCA and l2-normalized again.
|
||||
Then the aggregated global feature is l2-normalized.
|
||||
|
||||
Hyper-Params:
|
||||
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done.
|
||||
whiten (bool): whether to whiten each part.
|
||||
train_fea_dir (str): the path of features for training PCA.
|
||||
l2 (bool): whether to l2-normalize the training features.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"proj_dim": 0,
|
||||
"whiten": True,
|
||||
"train_fea_dir": "unknown",
|
||||
"l2": True,
|
||||
}
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(RMACPCA, self).__init__(feature_names, hps)
|
||||
|
||||
self.pca = dict()
|
||||
self._train(self._hyper_params["train_fea_dir"])
|
||||
|
||||
def _train(self, fea_dir: str) -> None:
|
||||
"""
|
||||
Train the PCA for R-MAC.
|
||||
|
||||
Args:
|
||||
fea_dir (str): the path of features for training PCA.
|
||||
"""
|
||||
fea, _, pos_info = feature_loader.load(fea_dir, self.feature_names)
|
||||
|
||||
fea_names = np.sort(list(pos_info.keys()))
|
||||
region_feas = list()
|
||||
for fea_name in fea_names:
|
||||
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
|
||||
region_fea = fea[:, st_idx: ed_idx]
|
||||
region_feas.append(region_fea)
|
||||
|
||||
train_fea = np.concatenate(region_feas, axis=0)
|
||||
if self._hyper_params["l2"]:
|
||||
train_fea = normalize(train_fea, norm="l2")
|
||||
pca = PCA(n_components=self._hyper_params["proj_dim"], whiten=self._hyper_params["whiten"])
|
||||
pca.fit(train_fea)
|
||||
|
||||
self.pca = {
|
||||
"pca": pca,
|
||||
"pos_info": pos_info
|
||||
}
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
pca = self.pca["pca"]
|
||||
pos_info = self.pca["pos_info"]
|
||||
|
||||
fea_names = np.sort(list(pos_info.keys()))
|
||||
|
||||
final_fea = None
|
||||
for fea_name in fea_names:
|
||||
st_idx, ed_idx = pos_info[fea_name][0], pos_info[fea_name][1]
|
||||
region_fea = fea[:, st_idx: ed_idx]
|
||||
region_fea = normalize(region_fea)
|
||||
region_fea = pca.transform(region_fea)
|
||||
region_fea = normalize(region_fea)
|
||||
if final_fea is None:
|
||||
final_fea = region_fea
|
||||
else:
|
||||
final_fea += region_fea
|
||||
final_fea = normalize(final_fea)
|
||||
|
||||
return final_fea
|
|
@@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dim_processors_base import DimProcessorBase
|
||||
from ...registry import DIMPROCESSORS
|
||||
from ...utils import feature_loader
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.decomposition import TruncatedSVD as SKSVD
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
@DIMPROCESSORS.register
|
||||
class SVD(DimProcessorBase):
|
||||
"""
|
||||
Do the SVD transformation for dimension reduction.
|
||||
|
||||
Hyper-Params:
|
||||
proj_dim (int): the dimension after reduction. If it is 0, then no reduction will be done
|
||||
(in SVD, the original dimension is reduced by 1).
|
||||
whiten (bool): whether to whiten the reduced features.
|
||||
train_fea_dir (str): the path of features for training SVD.
|
||||
l2 (bool): whether to l2-normalize the training features.
|
||||
"""
|
||||
default_hyper_params = {
|
||||
"proj_dim": 0,
|
||||
"whiten": True,
|
||||
"train_fea_dir": "unknown",
|
||||
"l2": True,
|
||||
}
|
||||
|
||||
def __init__(self, feature_names: List[str], hps: Dict or None = None):
|
||||
"""
|
||||
Args:
|
||||
feature_names (list): a list of feature names to be loaded.
|
||||
hps (dict): default hyper parameters in a dict (keys, values).
|
||||
"""
|
||||
super(SVD, self).__init__(feature_names, hps)
|
||||
|
||||
self.svd = SKSVD(n_components=self._hyper_params["proj_dim"])
|
||||
self.std = 0.0
|
||||
self._train(self._hyper_params["train_fea_dir"])
|
||||
|
||||
def _train(self, fea_dir: str) -> None:
|
||||
"""
|
||||
Train the SVD.
|
||||
|
||||
Args:
|
||||
fea_dir (str): the path of features for training SVD.
|
||||
"""
|
||||
train_fea, _, _ = feature_loader.load(fea_dir, self.feature_names)
|
||||
if self._hyper_params["l2"]:
|
||||
train_fea = normalize(train_fea, norm="l2")
|
||||
train_fea = self.svd.fit_transform(train_fea)
|
||||
self.std = train_fea.std(axis=0, keepdims=True)
|
||||
|
||||
def __call__(self, fea: np.ndarray) -> np.ndarray:
|
||||
ori_fea = fea
|
||||
proj_fea = self.svd.transform(ori_fea)
|
||||
if self._hyper_params["whiten"]:
|
||||
proj_fea = proj_fea / (self.std + 1e-6)
|
||||
return proj_fea
|
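A hedged usage sketch mirroring the PCA and PartPCA processors; the feature name and paths are placeholders. One design choice visible above: whitening is applied manually by dividing by the per-component standard deviation stored at training time, since sklearn's TruncatedSVD has no built-in whiten option.

```python
# Hypothetical sketch: truncated-SVD reduction of loaded features to 256 dims.
import numpy as np

svd_proc = SVD(
    feature_names=['pool5_GAP'],                      # placeholder feature name
    hps={'proj_dim': 256,
         'whiten': True,
         'train_fea_dir': '/data/features/gallery',   # placeholder path
         'l2': True},
)

fea = np.load('query_features.npy')  # placeholder: (N, D) array aligned with the training features
reduced = svd_proc(fea)              # -> (N, 256), whitened by the stored std
```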
Binary file not shown.