mirror of https://github.com/JDAI-CV/fast-reid.git
Compare commits
122 Commits
Commits in this range:

c9bc3ceb2f, 817c748e8c, 39887a102e, afe432b8c0, 4508251d74, 31d99b793f, fb1027de4c, 9208930c11, 581bddbbdc, c45d9b2498,
43da387b77, 10a5f38aaa, 100830e5ef, d9d6e19b2c, f4551a128b, ee3818e706, 00ffa10d3f, 2cac19ce31, ced654431b, 7e652fea2a,
44d1e04e9a, 7c1269d6c3, 2fee8327fa, 10b04b75ff, d792a69f3f, ac6256887c, 62ad5e1a8b, 4ce04e7cc2, f2286e7f55, 4f90197336,
7ed6240e2c, de81b3dbaa, 8f8cbf9411, 2d2279be6a, 764fa67fe9, 3c2eeb865d, 0572765085, 54f96ba78a, 6300bd756e, c3ac4f504c,
91ff631184, 8ab3554958, 256721cfde, 07b8251ccb, 2cabc3428a, 2b65882447, dbf1604231, ff8a958fff, 46b0681313, 0c8e3d9805,
bb6ddbf8b1, 37ccd3683d, 8276ccf4fd, e124a9afd3, e0ad8c70bc, fc67350e99, 1dce15efad, 0da5917064, 55300730e1, 44cee30dfc,
9288db6303, fb36b23678, 25cfa88fd9, be0a089e1f, 664ba4ae11, 890224f25c, 9d83550b67, 15c556c43a, 883fd4aede, 41c3d6ff4d,
9b5af4166e, cb7a1cb3e1, d7c1294d9e, 0cc9fb95a6, f57c5764e3, 68c190b53c, 44ad4b83b1, fcfa6800bb, 575aeaec3f, b9bda486f0,
69eb044b81, 96fd58c48f, 52b75b7974, e2a1e14bc3, 2f1836825c, 527b09c696, c4412d367f, 819c5a8ab4, cf46e5f071, 3854069f4e,
2f95da9a59, 0617e3eeb7, 159494e4a4, 96b7d3133a, 39e25e0457, c9537c97d1, 254a489eb1, 50ceeb0832, a8ae0fd1e9, 5f7d3d586e,
ebc375e51e, 1666c82db4, d2f9450041, dc5f1924dc, 2f877d239b, 77a91b1204, db8670db63, 6b4b935ce4, 1c08b7ac88, 885dc96608,
1ed1a13eed, 69454860d1, b786001ebd, 63ed07ab0d, efcc2f28cb, ef6ebf451b, a53fd17874, b5c3c0a24d, 274cd81dab, 6ab40bd43a,
e26182e6ec, fdaa4b1a84
Changed paths:

- .github/workflows
- configs/Market1501
- configs/VERIWild
- configs/VeRi
- configs/VehicleID
- datasets
- docker
- fastreid/data/transforms
- fastreid/modeling
Binary file not shown. (After: image, 272 KiB)
```diff
@@ -0,0 +1,19 @@
+name: Close inactive issues
+on:
+  schedule:
+    - cron: "30 1 * * *"
+
+jobs:
+  close-issues:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/stale@v3
+        with:
+          days-before-issue-stale: 30
+          days-before-issue-close: 14
+          stale-issue-label: "stale"
+          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
+          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
+          days-before-pr-stale: -1
+          days-before-pr-close: -1
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
```
```diff
@@ -1,8 +1,37 @@
.idea

logs

# compilation and distribution
__pycache__
.DS_Store
.vscode
_ext
*.pyc
*.pyd
*.so
logs/
.ipynb_checkpoints
logs
*.dll
*.egg-info/
build/
dist/
wheels/

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
*.ts
model_ts*.txt

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# editor settings
.idea
.vscode
_darcs
.DS_Store
```
```diff
@@ -0,0 +1,39 @@
+# Changelog
+
+### v1.3
+
+#### New Features
+- Vision Transformer backbone, see config in `configs/Market1501/bagtricks_vit.yml`
+- Self-Distillation with EMA update
+- Gradient Clip
+
+#### Improvements
+- Faster dataloader with pre-fetch thread and cuda stream
+- Optimize DDP training speed by removing `find_unused_parameters` in DDP
+
+
+### v1.2 (06/04/2021)
+
+#### New Features
+
+- Multiple machine training support
+- [RepVGG](https://github.com/DingXiaoH/RepVGG) backbone
+- [Partial FC](projects/FastFace)
+
+#### Improvements
+
+- Torch2trt pipeline
+- Decouple linear transforms and softmax
+- config decorator
+
+### v1.1 (29/01/2021)
+
+#### New Features
+
+- NAIC20(reid track) [1-st solution](projects/NAIC20)
+- Multi-teacher Knowledge Distillation
+- TRT network definition APIs in [FastRT](projects/FastRT)
+
+#### Bug Fixes
+
+#### Improvements
```
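The "Self-Distillation with EMA update" entry above keeps an exponential-moving-average copy of the model weights. As a point of reference, the standard EMA update rule such a feature builds on looks like this (a minimal sketch; `ema_update` and its signature are illustrative, not fastreid's actual API):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # Classic EMA: ema <- decay * ema + (1 - decay) * current weights.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```

The EMA copy changes slowly, so it can serve as a stable teacher for self-distillation while the raw model keeps training.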
````diff
@@ -21,7 +21,7 @@ You may want to use it as a reference to write your own training script.
 To train a model with "train_net.py", first set up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:
 
 ```bash
-./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml MODEL.DEVICE "cuda:0"
+python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml MODEL.DEVICE "cuda:0"
 ```
 
 The configs are made for 1-GPU training.
````
````diff
@@ -29,14 +29,34 @@ The configs are made for 1-GPU training.
 If you want to train model with 4 GPUs, you can run:
 
 ```bash
-python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
+python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
 ```
 
+If you want to train model with multiple machines, you can run:
+
+```
+# machine 1
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 0 --dist-url tcp://ip:port
+
+# machine 2
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 1 --dist-url tcp://ip:port
+```
+
+Make sure the dataset path and code are the same on different machines, and that the machines can communicate with each other.
+
 To evaluate a model's performance, use
 
 ```bash
-python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
+python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
 MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
 ```
 
-For more options, see `./tools/train_net.py -h`.
+For more options, see `python3 tools/train_net.py -h`.
````
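For readers unfamiliar with the flags in the multi-machine example, they map onto the usual `torch.distributed` rendezvous: `--dist-url` names the TCP endpoint all processes connect to, and the global rank is derived from the machine rank and the local GPU index. A rough sketch of that correspondence (illustrative only; `init_process` is hypothetical, not fastreid's launcher code):

```python
import torch.distributed as dist

def init_process(machine_rank, local_rank, num_gpus, num_machines, dist_url):
    # e.g. machine_rank=1 with num_gpus=4 -> global ranks 4..7 on machine 2.
    global_rank = machine_rank * num_gpus + local_rank
    dist.init_process_group(
        backend="nccl",                 # NCCL_SOCKET_IFNAME selects its NIC
        init_method=dist_url,           # e.g. "tcp://ip:port"
        world_size=num_machines * num_gpus,
        rank=global_rank,
    )
```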
````diff
@@ -21,8 +21,9 @@
 conda create -n fastreid python=3.7
 conda activate fastreid
 conda install pytorch==1.6.0 torchvision tensorboard -c pytorch
-pip install -r requirements
+pip install -r docs/requirements.txt
 ```
 
 # Set up with Dockder
-comming soon
+
+Please check the [docker folder](docker)
````
```diff
@@ -154,7 +154,7 @@ Bag of Specials(BoS):
 
 | Method | Pretrained | Rank@1 | mAP | mINP | download |
 | :---: | :---: | :---: |:---: | :---: | :---:|
-| [SBS(R50-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/VeRi/sbs_R50-ibn.yml) | ImageNet | 97.0% | 81.9% | 46.3% | -|
+| [SBS(R50-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/VeRi/sbs_R50-ibn.yml) | ImageNet | 97.0% | 81.9% | 46.3% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/veri_sbs_R50-ibn.pth) |
 
 ### VehicleID Baseline
```
```diff
@@ -193,7 +193,7 @@ Test protocol: 10-fold cross-validation; trained on 4 NVIDIA P40 GPU.
 <td align="center">96.0%</td>
 <td align="center">80.6%</td>
 <td align="center">93.9%</td>
-<td align="center">-</td>
+<td align="center"><a href="https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/vehicleid_bot_R50-ibn.pth">model</a></td>
 </tr>
 </tbody>
 </table>
```
```diff
@@ -241,7 +241,7 @@ Test protocol: Trained on 4 NVIDIA P40 GPU.
 <td align="center">92.5%</td>
 <td align="center">77.3%</td>
 <td align="center">49.8%</td>
-<td align="center">-</td>
+<td align="center"><a href="https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/veriwild_bot_R50-ibn.pth">model</a></td>
 </tr>
 </tbody>
 </table>
```
README.md
````diff
@@ -1,50 +1,72 @@
 <img src=".github/FastReID-Logo.png" width="300" >
 
-FastReID is a research platform that implements state-of-the-art re-identification algorithms. It is a groud-up rewrite of the previous version, [reid strong baseline](https://github.com/michuanhaohao/reid-strong-baseline).
+[](https://gitter.im/fast-reid/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
+
+Gitter: [fast-reid/community](https://gitter.im/fast-reid/community?utm_source=share-link&utm_medium=link&utm_campaign=share-link)
+
+Wechat:
+
+<img src=".github/wechat_group.png" width="150" >
+
+FastReID is a research platform that implements state-of-the-art re-identification algorithms. It is a ground-up rewrite of the previous version, [reid strong baseline](https://github.com/michuanhaohao/reid-strong-baseline).
 
 ## What's New
 
-- [Jan 2021] FastReID V1.0 has been released!🎉
-Support many tasks beyond reid, such image retrieval and face recognition. See [projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects).
-- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
-- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
+- [Sep 2021] [DG-ReID](https://github.com/xiaomingzhid/sskd) is updated, you can check the [paper](https://arxiv.org/pdf/2108.05045.pdf).
+- [June 2021] [Contiguous parameters](https://github.com/PhilJd/contiguous_pytorch_params) is supported, now it can accelerate ~20%.
+- [May 2021] Vision Transformer backbone supported, see `configs/Market1501/bagtricks_vit.yml`.
+- [Apr 2021] Partial FC supported in [FastFace](projects/FastFace)!
+- [Jan 2021] TRT network definition APIs in [FastRT](projects/FastRT) have been released!
+Thanks for [Darren](https://github.com/TCHeish)'s contribution.
+- [Jan 2021] NAIC20(reid track) [1-st solution](projects/NAIC20) based on fastreid has been released!
+- [Jan 2021] FastReID V1.0 has been released!🎉
+Support many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
+- [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
+- [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
 - [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
-- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
+- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
 - [Aug 2020] ONNX/TensorRT converter is supported.
 - [Jul 2020] Distributed training with multiple GPUs, it trains much faster.
 - Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multi-datasets simultaneously, etc.
-- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open source more research projects in this way.
+- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
 - Remove [ignite](https://github.com/pytorch/ignite)(a high-level library) dependency and powered by [PyTorch](https://pytorch.org/).
 
-We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
+We write a [fastreid intro](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2020/05/29/fastreid.html)
+and [fastreid v1.0](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2021/04/28/fastreid-v1.html) about this toolbox.
+
+## Changelog
+
+Please refer to [changelog.md](CHANGELOG.md) for details and release history.
 
 ## Installation
 
-See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/INSTALL.md).
+See [INSTALL.md](INSTALL.md).
 
 ## Quick Start
 
 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
 
-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/GETTING_STARTED.md).
+See [GETTING_STARTED.md](GETTING_STARTED.md).
 
-Learn more at out [documentation](). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
+Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](projects) for some projects that are built on top of fastreid.
 
 ## Model Zoo and Baselines
 
-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/docs/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](MODEL_ZOO.md).
 
 ## Deployment
 
-We provide some examples and scripts to convert fastreid model to Caffe, ONNX and TensorRT format in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
+We provide some examples and scripts to convert fastreid model to Caffe, ONNX and TensorRT format in [Fastreid deploy](tools/deploy).
 
 ## License
 
-Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
+Fastreid is released under the [Apache 2.0 license](LICENSE).
 
-## Citing Fastreid
+## Citing FastReID
 
-If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
+If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
 
 ```BibTeX
 @article{he2020fastreid,
````
```diff
@@ -5,7 +5,7 @@ MODEL:
     WITH_NL: True
 
   HEADS:
-    POOL_LAYER: gempool
+    POOL_LAYER: GeneralizedMeanPooling
 
   LOSSES:
     NAME: ("CrossEntropyLoss", "TripletLoss")
```
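`GeneralizedMeanPooling` (the new name for `gempool`) is standard GeM pooling: raise activations to a power p, average spatially, then take the p-th root, so p = 1 recovers average pooling and large p approaches max pooling. A minimal sketch (illustrative, not fastreid's exact layer; the `P` variant seen in other configs makes p learnable):

```python
import torch
import torch.nn.functional as F

def gem_pool(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    # x: (N, C, H, W) feature map -> (N, C, 1, 1) pooled descriptor.
    return F.avg_pool2d(x.clamp(min=eps).pow(p), x.shape[-2:]).pow(1.0 / p)
```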
```diff
@@ -8,8 +8,8 @@ MODEL:
 
   HEADS:
     NECK_FEAT: after
-    POOL_LAYER: gempoolP
-    CLS_LAYER: circleSoftmax
+    POOL_LAYER: GeneralizedMeanPoolingP
+    CLS_LAYER: CircleSoftmax
     SCALE: 64
     MARGIN: 0.35

@@ -29,20 +29,20 @@ INPUT:
   SIZE_TRAIN: [ 384, 128 ]
   SIZE_TEST: [ 384, 128 ]
 
-  DO_AUTOAUG: True
-  AUTOAUG_PROB: 0.1
+  AUTOAUG:
+    ENABLED: True
+    PROB: 0.1
 
 DATALOADER:
   NUM_INSTANCE: 16
 
 SOLVER:
-  FP16_ENABLED: False
+  AMP:
+    ENABLED: True
   OPT: Adam
   MAX_EPOCH: 60
   BASE_LR: 0.00035
   BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64
 
   SCHED: CosineAnnealingLR

@@ -50,7 +50,7 @@ SOLVER:
   ETA_MIN_LR: 0.0000007
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   FREEZE_ITERS: 1000
```
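The `FP16_ENABLED` → `AMP.ENABLED` rename above tracks PyTorch's native automatic mixed precision. A minimal sketch of the `torch.cuda.amp` pattern such a flag typically switches on (toy model and data, not fastreid's trainer):

```python
import torch

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    x = torch.randn(64, 128, device="cuda")
    y = torch.randint(0, 10, (64,), device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # forward pass runs in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()     # loss scaling avoids fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```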
```diff
@@ -14,9 +14,9 @@ MODEL:
     NAME: EmbeddingHead
     NORM: BN
     WITH_BNNECK: True
-    POOL_LAYER: avgpool
+    POOL_LAYER: GlobalAvgPool
     NECK_FEAT: before
-    CLS_LAYER: linear
+    CLS_LAYER: Linear
 
   LOSSES:
     NAME: ("CrossEntropyLoss", "TripletLoss",)
```
```diff
@@ -34,25 +34,30 @@ MODEL:
 INPUT:
   SIZE_TRAIN: [ 256, 128 ]
   SIZE_TEST: [ 256, 128 ]
 
   REA:
     ENABLED: True
     PROB: 0.5
-  DO_PAD: True
+
+  FLIP:
+    ENABLED: True
+
+  PADDING:
+    ENABLED: True
 
 DATALOADER:
-  PK_SAMPLER: True
-  NAIVE_WAY: True
+  SAMPLER_TRAIN: NaiveIdentitySampler
   NUM_INSTANCE: 4
   NUM_WORKERS: 8
 
 SOLVER:
-  FP16_ENABLED: True
+  AMP:
+    ENABLED: True
   OPT: Adam
   MAX_EPOCH: 120
   BASE_LR: 0.00035
   BIAS_LR_FACTOR: 2.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   WEIGHT_DECAY_NORM: 0.0005
   IMS_PER_BATCH: 64
 
   SCHED: MultiStepLR

@@ -60,7 +65,7 @@ SOLVER:
   GAMMA: 0.1
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 30
```
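Several hunks above replace `WARMUP_EPOCHS` with `WARMUP_ITERS`, i.e. warmup is now measured in iterations rather than epochs. These options parameterize the usual linear warmup rule, sketched below (not fastreid's scheduler code):

```python
def warmup_lr(base_lr: float, it: int, warmup_iters: int, warmup_factor: float) -> float:
    # Linearly ramp the LR multiplier from warmup_factor up to 1.0.
    if it >= warmup_iters:
        return base_lr
    alpha = it / warmup_iters
    return base_lr * (warmup_factor * (1.0 - alpha) + alpha)

print(warmup_lr(0.00035, 0, 2000, 0.1))     # 3.5e-05 at the first iteration
print(warmup_lr(0.00035, 2000, 2000, 0.1))  # full base LR once warmup ends
```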
```diff
@@ -0,0 +1,88 @@
+
+MODEL:
+  META_ARCHITECTURE: Baseline
+  PIXEL_MEAN: [127.5, 127.5, 127.5]
+  PIXEL_STD: [127.5, 127.5, 127.5]
+
+  BACKBONE:
+    NAME: build_vit_backbone
+    DEPTH: base
+    FEAT_DIM: 768
+    PRETRAIN: True
+    PRETRAIN_PATH: /export/home/lxy/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth
+    STRIDE_SIZE: (16, 16)
+    DROP_PATH_RATIO: 0.1
+    DROP_RATIO: 0.0
+    ATT_DROP_RATE: 0.0
+
+  HEADS:
+    NAME: EmbeddingHead
+    NORM: BN
+    WITH_BNNECK: True
+    POOL_LAYER: Identity
+    NECK_FEAT: before
+    CLS_LAYER: Linear
+
+  LOSSES:
+    NAME: ("CrossEntropyLoss", "TripletLoss",)
+
+    CE:
+      EPSILON: 0. # no smooth
+      SCALE: 1.
+
+    TRI:
+      MARGIN: 0.0
+      HARD_MINING: True
+      NORM_FEAT: False
+      SCALE: 1.
+
+INPUT:
+  SIZE_TRAIN: [ 256, 128 ]
+  SIZE_TEST: [ 256, 128 ]
+
+  REA:
+    ENABLED: True
+    PROB: 0.5
+
+  FLIP:
+    ENABLED: True
+
+  PADDING:
+    ENABLED: True
+
+DATALOADER:
+  SAMPLER_TRAIN: NaiveIdentitySampler
+  NUM_INSTANCE: 4
+  NUM_WORKERS: 8
+
+SOLVER:
+  AMP:
+    ENABLED: False
+  OPT: SGD
+  MAX_EPOCH: 120
+  BASE_LR: 0.008
+  WEIGHT_DECAY: 0.0001
+  IMS_PER_BATCH: 64
+
+  SCHED: CosineAnnealingLR
+  ETA_MIN_LR: 0.000016
+
+  WARMUP_FACTOR: 0.01
+  WARMUP_ITERS: 1000
+
+  CLIP_GRADIENTS:
+    ENABLED: True
+
+  CHECKPOINT_PERIOD: 30
+
+TEST:
+  EVAL_PERIOD: 5
+  IMS_PER_BATCH: 128
+
+CUDNN_BENCHMARK: True
+
+DATASETS:
+  NAMES: ("Market1501",)
+  TESTS: ("Market1501",)
+
+OUTPUT_DIR: logs/market1501/sbs_vit_base
```
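The `CLIP_GRADIENTS.ENABLED: True` entry in this new ViT config corresponds to the "Gradient Clip" feature from the changelog. A minimal sketch of the standard PyTorch call such a flag usually maps onto (toy model; not fastreid's trainer code):

```python
import torch

model = torch.nn.Linear(8, 2)              # toy stand-in for the ViT backbone
loss = model(torch.randn(4, 8)).sum()
loss.backward()
# Rescale all gradients so their global L2 norm is at most max_norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
```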
```diff
@@ -7,8 +7,10 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
 
   HEADS:
-    POOL_LAYER: gempool
+    POOL_LAYER: GeneralizedMeanPooling
 
   LOSSES:
     TRI:
       HARD_MINING: False

@@ -19,15 +21,15 @@ DATASETS:
   TESTS: ("SmallVeRiWild", "MediumVeRiWild", "LargeVeRiWild",)
 
 SOLVER:
-  IMS_PER_BATCH: 128
-  MAX_ITER: 60
-  STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  IMS_PER_BATCH: 512 # 512 For 4 GPUs
+  MAX_EPOCH: 120
+  STEPS: [30, 70, 90]
+  WARMUP_ITERS: 5000
 
   CHECKPOINT_PERIOD: 20
 
 TEST:
-  EVAL_PERIOD: 20
+  EVAL_PERIOD: 10
   IMS_PER_BATCH: 128
 
 OUTPUT_DIR: logs/veriwild/bagtricks_R50-ibn_4gpu
```
```diff
@@ -7,6 +7,7 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
+    WITH_NL: True
 
 SOLVER:
   OPT: SGD

@@ -14,19 +15,21 @@ SOLVER:
   ETA_MIN_LR: 7.7e-5
 
   IMS_PER_BATCH: 64
-  MAX_ITER: 60
-  DELAY_ITERS: 30
-  WARMUP_EPOCHS: 10
-  FREEZE_ITERS: 10
+  MAX_EPOCH: 60
+  WARMUP_ITERS: 3000
+  FREEZE_ITERS: 3000
 
-  CHECKPOINT_PERIOD: 20
+  CHECKPOINT_PERIOD: 10
 
 DATASETS:
   NAMES: ("VeRi",)
   TESTS: ("VeRi",)
 
+DATALOADER:
+  SAMPLER_TRAIN: BalancedIdentitySampler
+
 TEST:
-  EVAL_PERIOD: 20
-  IMS_PER_BATCH: 128
+  EVAL_PERIOD: 10
+  IMS_PER_BATCH: 256
 
 OUTPUT_DIR: logs/veri/sbs_R50-ibn
```
```diff
@@ -8,7 +8,8 @@ MODEL:
   BACKBONE:
     WITH_IBN: True
   HEADS:
-    POOL_LAYER: gempool
+    POOL_LAYER: GeneralizedMeanPooling
 
   LOSSES:
     TRI:
       HARD_MINING: False

@@ -22,9 +23,9 @@ SOLVER:
   BIAS_LR_FACTOR: 1.
 
   IMS_PER_BATCH: 512
-  MAX_ITER: 60
+  MAX_EPOCH: 60
   STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 20
```
````diff
@@ -6,7 +6,7 @@ You can set the location for builtin datasets by `export FASTREID_DATASETS=/path
 
 The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these buildin datasets.
 
-## Expected dataset structure for Market1501
+## Expected dataset structure for [Market1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf)
 
 1. Download dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [google driver](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
 2. Extract dataset. The dataset structure would like:

@@ -18,7 +18,7 @@ datasets/
 bounding_box_train/
 ```
 
-## Expected dataset structure for DukeMTMC
+## Expected dataset structure for [DukeMTMC-reID](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zheng_Unlabeled_Samples_Generated_ICCV_2017_paper.pdf)
 
 1. Download datasets to `datasets/`
 2. Extract dataset. The dataset structure would like:

@@ -30,7 +30,7 @@ datasets/
 bounding_box_test/
 ```
 
-## Expected dataset structure for MSMT17
+## Expected dataset structure for [MSMT17](https://arxiv.org/abs/1711.08565)
 
 1. Download datasets to `datasets/`
 2. Extract dataset. The dataset structure would like:
````
````diff
@@ -2,9 +2,10 @@
 
 We provide a command line tool to run a simple demo of builtin models.
 
 You can run this command to get cosine similarities between different images
 
 ```bash
-cd demo/
-sh run_demo.sh
+python demo/visualize_result.py --config-file logs/dukemtmc/mgn_R50-ibn/config.yaml \
+--parallel --vis-label --dataset-name DukeMTMC --output logs/mgn_duke_vis \
+--opts MODEL.WEIGHTS logs/dukemtmc/mgn_R50-ibn/model_final.pth
 ```
````
demo/demo.py
```diff
@@ -9,20 +9,23 @@ import glob
 import os
 import sys
 
+import torch.nn.functional as F
 import cv2
 import numpy as np
 import tqdm
 from torch.backends import cudnn
 
-sys.path.append('..')
+sys.path.append('.')
 
 from fastreid.config import get_cfg
 from fastreid.utils.logger import setup_logger
 from fastreid.utils.file_io import PathManager
 
 from predictor import FeatureExtractionDemo
 
 # import some modules added in project like this below
-# from projects.PartialReID.partialreid import *
+# sys.path.append("projects/PartialReID")
+# from partialreid import *
 
 cudnn.benchmark = True
 setup_logger(name="fastreid")

@@ -70,6 +73,13 @@ def get_parser():
     return parser
 
 
+def postprocess(features):
+    # Normalize feature to compute cosine distance
+    features = F.normalize(features)
+    features = features.cpu().data.numpy()
+    return features
+
+
 if __name__ == '__main__':
     args = get_parser().parse_args()
     cfg = setup_cfg(args)

@@ -83,5 +93,5 @@ if __name__ == '__main__':
     for path in tqdm.tqdm(args.input):
         img = cv2.imread(path)
         feat = demo.run_on_image(img)
-        feat = feat.numpy()
-        np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
+        feat = postprocess(feat)
+        np.save(os.path.join(args.output, os.path.basename(path).split('.')[0] + '.npy'), feat)
```
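The new `postprocess` helper L2-normalizes features before saving, so cosine similarity between two saved descriptors reduces to a plain dot product. A small sketch of that downstream use (the file names are hypothetical outputs of demo.py):

```python
import numpy as np

a = np.load("query.npy")      # (1, feat_dim), already L2-normalized
b = np.load("gallery.npy")    # (1, feat_dim)
cosine_sim = float(a @ b.T)   # dot product of unit vectors == cosine similarity
```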
```diff
@@ -6,11 +6,11 @@
 import atexit
 import bisect
-from collections import deque
 
 import cv2
 import torch
 import torch.multiprocessing as mp
+from collections import deque
 
 from fastreid.engine import DefaultPredictor

@@ -70,16 +70,16 @@ class FeatureExtractionDemo(object):
                 if cnt >= buffer_size:
                     batch = batch_data.popleft()
                     predictions = self.predictor.get()
-                    yield predictions, batch["targets"].numpy(), batch["camids"].numpy()
+                    yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy()
 
             while len(batch_data):
                 batch = batch_data.popleft()
                 predictions = self.predictor.get()
-                yield predictions, batch["targets"].numpy(), batch["camids"].numpy()
+                yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy()
         else:
             for batch in data_loader:
                 predictions = self.predictor(batch["images"])
-                yield predictions, batch["targets"].numpy(), batch["camids"].numpy()
+                yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy()
 
 
 class AsyncPredictor:
```
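The repeated change from `.numpy()` to `.cpu().numpy()` matters because `Tensor.numpy()` raises a TypeError for CUDA tensors; moving the tensor to host memory first is the standard fix:

```python
import torch

t = torch.ones(3, device="cuda" if torch.cuda.is_available() else "cpu")
arr = t.cpu().numpy()  # plain t.numpy() would fail if t lived on the GPU
```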
```diff
@@ -1,3 +0,0 @@
-python demo/visualize_result.py --config-file logs/dukemtmc/mgn_R50-ibn/config.yaml \
---parallel --vis-label --dataset-name 'DukeMTMC' --output logs/mgn_duke_vis \
---opts MODEL.WEIGHTS logs/dukemtmc/mgn_R50-ibn/model_final.pth
```
```diff
@@ -24,7 +24,8 @@ from fastreid.utils.visualizer import Visualizer
 
 # import some modules added in project
 # for example, add partial reid like this below
-# from projects.PartialReID.partialreid import *
+# sys.path.append("projects/PartialReID")
+# from partialreid import *
 
 cudnn.benchmark = True
 setup_logger(name="fastreid")

@@ -35,6 +36,7 @@ logger = logging.getLogger('fastreid.visualize_result')
 def setup_cfg(args):
     # load config from file and command-line arguments
     cfg = get_cfg()
+    # add_partialreid_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
     cfg.freeze()

@@ -100,7 +102,7 @@ def get_parser():
 if __name__ == '__main__':
     args = get_parser().parse_args()
     cfg = setup_cfg(args)
-    test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
+    test_loader, num_query = build_reid_test_loader(cfg, dataset_name=args.dataset_name)
     demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
 
     logger.info("Start extracting image features")
```
```diff
@@ -0,0 +1,28 @@
+FROM nvidia/cuda:10.1-cudnn7-devel
+
+# https://github.com/NVIDIA/nvidia-docker/issues/1632
+RUN rm /etc/apt/sources.list.d/cuda.list
+RUN rm /etc/apt/sources.list.d/nvidia-ml.list
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt-get update && apt-get install -y \
+    python3-opencv ca-certificates python3-dev git wget sudo ninja-build
+RUN ln -sv /usr/bin/python3 /usr/bin/python
+
+# create a non-root user
+ARG USER_ID=1000
+RUN useradd -m --no-log-init --system --uid ${USER_ID} appuser -g sudo
+RUN echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
+USER appuser
+WORKDIR /home/appuser
+
+# https://github.com/facebookresearch/detectron2/issues/3933
+ENV PATH="/home/appuser/.local/bin:${PATH}"
+RUN wget https://bootstrap.pypa.io/pip/3.6/get-pip.py && \
+    python3 get-pip.py --user && \
+    rm get-pip.py
+
+# install dependencies
+# See https://pytorch.org/ for other options if you use a different version of CUDA
+RUN pip install --user tensorboard cmake   # cmake from apt-get is too old
+RUN pip install --user torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html
+RUN pip install --user -i https://pypi.tuna.tsinghua.edu.cn/simple tensorboard opencv-python cython yacs termcolor scikit-learn tabulate gdown gpustat faiss-gpu ipdb h5py
```
````diff
@@ -0,0 +1,22 @@
+# Use the container
+
+```shell script
+cd docker/
+# Build:
+docker build -t=fastreid:v0 .
+# Launch (requires GPUs)
+nvidia-docker run -v server_path:docker_path --name=fastreid --net=host --ipc=host -it fastreid:v0 /bin/sh
+```
+
+## Install new dependencies
+
+Add the following to `Dockerfile` to make persistent changes.
+
+```shell script
+RUN sudo apt-get update && sudo apt-get install -y vim
+```
+
+Or run them in the container to make temporary changes.
+
+## A more complete docker container
+
+If you want to use a complete docker container which contains many useful tools, you can check my development environment [Dockerfile](https://github.com/L1aoXingyu/fastreid_docker)
````
```diff
@@ -0,0 +1 @@
+_build
```
```diff
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
```
```diff
@@ -0,0 +1,16 @@
+# Read the docs:
+
+The latest documentation built from this directory is available at [detectron2.readthedocs.io](https://detectron2.readthedocs.io/).
+Documents in this directory are not meant to be read on github.
+
+# Build the docs:
+
+1. Install detectron2 according to [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).
+2. Install additional libraries required to build docs:
+  - docutils==0.16
+  - Sphinx==3.0.0
+  - recommonmark==0.6.0
+  - sphinx_rtd_theme
+  - mock
+
+3. Run `make html` from this directory.
```
```diff
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * some extra css to make markdown look similar between github/sphinx
+ */
+
+/*
+ * Below is for install.md:
+ */
+.rst-content code {
+  white-space: pre;
+  border: 0px;
+}
+
+.rst-content th {
+  border: 1px solid #e1e4e5;
+}
+
+.rst-content th p {
+  /* otherwise will be default 24px for regular paragraph */
+  margin-bottom: 0px;
+}
+
+div.section > details {
+  padding-bottom: 1em;
+}
```
```diff
@@ -0,0 +1,356 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# flake8: noqa
+
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+from unittest import mock
+from sphinx.domains import Domain
+from typing import Dict, List, Tuple
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+import sphinx_rtd_theme
+
+
+class GithubURLDomain(Domain):
+    """
+    Resolve certain links in markdown files to github source.
+    """
+
+    name = "githuburl"
+    ROOT = "https://github.com/JDAI-CV/fast-reid/tree/master"
+    LINKED_DOC = ["tutorials/install", "tutorials/getting_started"]
+
+    def resolve_any_xref(self, env, fromdocname, builder, target, node, contnode):
+        github_url = None
+        if not target.endswith("html") and target.startswith("../../"):
+            url = target.replace("../", "")
+            github_url = url
+        if fromdocname in self.LINKED_DOC:
+            # unresolved links in these docs are all github links
+            github_url = target
+
+        if github_url is not None:
+            if github_url.endswith("MODEL_ZOO") or github_url.endswith("README"):
+                # bug of recommonmark.
+                # https://github.com/readthedocs/recommonmark/blob/ddd56e7717e9745f11300059e4268e204138a6b1/recommonmark/parser.py#L152-L155
+                github_url += ".md"
+            print("Ref {} resolved to github:{}".format(target, github_url))
+            contnode["refuri"] = self.ROOT + github_url
+            return [("githuburl:any", contnode)]
+        else:
+            return []
+
+
+# to support markdown
+from recommonmark.parser import CommonMarkParser
+
+sys.path.insert(0, os.path.abspath("../"))
+os.environ["DOC_BUILDING"] = "True"
+DEPLOY = os.environ.get("READTHEDOCS") == "True"
+
+
+# -- Project information -----------------------------------------------------
+
+# fmt: off
+try:
+    import torch  # noqa
+except ImportError:
+    for m in [
+        "torch", "torchvision", "torch.nn", "torch.nn.parallel", "torch.distributed", "torch.multiprocessing", "torch.autograd",
+        "torch.autograd.function", "torch.nn.modules", "torch.nn.modules.utils", "torch.utils", "torch.utils.data", "torch.onnx",
+        "torchvision", "torchvision.ops",
+    ]:
+        sys.modules[m] = mock.Mock(name=m)
+    sys.modules['torch'].__version__ = "1.5"  # fake version
+
+for m in [
+    "cv2", "scipy", "portalocker",
+    "google", "google.protobuf", "google.protobuf.internal", "onnx",
+]:
+    sys.modules[m] = mock.Mock(name=m)
+# fmt: on
+sys.modules["cv2"].__version__ = "3.4"
+
+import fastreid  # isort: skip
+
+
+project = "fastreid"
+copyright = "2019-2020, fastreid contributors"
+author = "fastreid contributors"
+
+# The short X.Y version
+version = fastreid.__version__
+# The full version, including alpha/beta/rc tags
+release = version
+
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+needs_sphinx = "3.0"
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "recommonmark",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.coverage",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.githubpages",
+]
+
+# -- Configurations for plugins ------------
+napoleon_google_docstring = True
+napoleon_include_init_with_doc = True
+napoleon_include_special_with_doc = True
+napoleon_numpy_docstring = False
+napoleon_use_rtype = False
+autodoc_inherit_docstrings = False
+autodoc_member_order = "bysource"
+
+if DEPLOY:
+    intersphinx_timeout = 10
+else:
+    # skip this when building locally
+    intersphinx_timeout = 0.1
+intersphinx_mapping = {
+    "python": ("https://docs.python.org/3.6", None),
+    "numpy": ("https://docs.scipy.org/doc/numpy/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
+}
+# -------------------------
+
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+source_suffix = [".rst", ".md"]
+
+# The master toctree document.
+master_doc = "index"
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "build", "README.md", "tutorials/README.md"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+
+# -- Options for HTML output -------------------------------------------------
+
+html_theme = "sphinx_rtd_theme"
+html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further.  For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+html_css_files = ["css/custom.css"]
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# The default sidebars (for documents that don't match any pattern) are
+# defined by theme itself.  Builtin themes are using these templates by
+# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
+# 'searchbox.html']``.
+#
+# html_sidebars = {}
+
+
+# -- Options for HTMLHelp output ---------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "fastreiddoc"
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_elements = {
+    # The paper size ('letterpaper' or 'a4paper').
+    #
+    # 'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+    #
+    # 'pointsize': '10pt',
+    # Additional stuff for the LaTeX preamble.
+    #
+    # 'preamble': '',
+    # Latex figure (float) alignment
+    #
+    # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+#  author, documentclass [howto, manual, or own class]).
+latex_documents = [
+    (master_doc, "fastreid.tex", "fastreid Documentation", "fastreid contributors", "manual")
+]
+
+
+# -- Options for manual page output ------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, "fastreid", "fastreid Documentation", [author], 1)]
+
+
+# -- Options for Texinfo output ----------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+#  dir menu entry, description, category)
+texinfo_documents = [
+    (
+        master_doc,
+        "fastreid",
+        "fastreid Documentation",
+        author,
+        "fastreid",
+        "One line description of project.",
+        "Miscellaneous",
+    )
+]
+
+
+# -- Options for todo extension ----------------------------------------------
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+
+def autodoc_skip_member(app, what, name, obj, skip, options):
+    # we hide something deliberately
+    if getattr(obj, "__HIDE_SPHINX_DOC__", False):
+        return True
+
+    # Hide some that are deprecated or not intended to be used
+    HIDDEN = {
+        # "ResNetBlockBase",
+        "GroupedBatchSampler",
+        # "build_transform_gen",
+        # "export_caffe2_model",
+        # "export_onnx_model",
+        # "apply_transform_gens",
+        # "TransformGen",
+        # "apply_augmentations",
+        # "StandardAugInput",
+        # "build_batch_data_loader",
+        # "draw_panoptic_seg_predictions",
+    }
+    try:
+        if name in HIDDEN or (
+            hasattr(obj, "__doc__") and obj.__doc__.lower().strip().startswith("deprecated")
+        ):
+            print("Skipping deprecated object: {}".format(name))
+            return True
+    except:
+        pass
+    return skip
+
+
+_PAPER_DATA = {
+    "resnet": ("1512.03385", "Deep Residual Learning for Image Recognition"),
+    "fpn": ("1612.03144", "Feature Pyramid Networks for Object Detection"),
+    "mask r-cnn": ("1703.06870", "Mask R-CNN"),
+    "faster r-cnn": (
+        "1506.01497",
+        "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks",
+    ),
+    "deformconv": ("1703.06211", "Deformable Convolutional Networks"),
+    "deformconv2": ("1811.11168", "Deformable ConvNets v2: More Deformable, Better Results"),
+    "panopticfpn": ("1901.02446", "Panoptic Feature Pyramid Networks"),
+    "retinanet": ("1708.02002", "Focal Loss for Dense Object Detection"),
+    "cascade r-cnn": ("1712.00726", "Cascade R-CNN: Delving into High Quality Object Detection"),
+    "lvis": ("1908.03195", "LVIS: A Dataset for Large Vocabulary Instance Segmentation"),
+    "rrpn": ("1703.01086", "Arbitrary-Oriented Scene Text Detection via Rotation Proposals"),
+    "imagenet in 1h": ("1706.02677", "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"),
+    "xception": ("1610.02357", "Xception: Deep Learning with Depthwise Separable Convolutions"),
+    "mobilenet": (
+        "1704.04861",
+        "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications",
+    ),
+}
+
+
+def paper_ref_role(
+    typ: str,
+    rawtext: str,
+    text: str,
+    lineno: int,
+    inliner,
+    options: Dict = {},
+    content: List[str] = [],
+):
+    """
+    Parse :paper:`xxx`. Similar to the "extlinks" sphinx extension.
+    """
+    from docutils import nodes, utils
+    from sphinx.util.nodes import split_explicit_title
+
+    text = utils.unescape(text)
+    has_explicit_title, title, link = split_explicit_title(text)
+    link = link.lower()
+    if link not in _PAPER_DATA:
+        inliner.reporter.warning("Cannot find paper " + link)
+        paper_url, paper_title = "#", link
+    else:
+        paper_url, paper_title = _PAPER_DATA[link]
+        if "/" not in paper_url:
+            paper_url = "https://arxiv.org/abs/" + paper_url
+        if not has_explicit_title:
+            title = paper_title
+    pnode = nodes.reference(title, title, internal=False, refuri=paper_url)
+    return [pnode], []
+
+
+def setup(app):
+    from recommonmark.transform import AutoStructify
+
+    app.add_domain(GithubURLDomain)
+    app.connect("autodoc-skip-member", autodoc_skip_member)
+    app.add_role("paper", paper_ref_role)
+    app.add_config_value(
+        "recommonmark_config",
+        {"enable_math": True, "enable_inline_math": True, "enable_eval_rst": True},
+        True,
+    )
+    app.add_transform(AutoStructify)
```
```diff
@@ -0,0 +1,14 @@
+.. fastreid documentation master file, created by
+   sphinx-quickstart on Sat Sep 21 13:46:45 2019.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Welcome to fastreid's documentation!
+======================================
+
+.. toctree::
+   :maxdepth: 2
+
+   tutorials/index
+   notes/index
+   modules/index
```
```diff
@@ -0,0 +1,7 @@
+fastreid.checkpoint
+=============================
+
+.. automodule:: fastreid.utils.checkpoint
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -0,0 +1,19 @@
+fastreid.config
+=========================
+
+Related tutorials: :doc:`../tutorials/configs`, :doc:`../tutorials/extend`.
+
+.. automodule:: fastreid.config
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :inherited-members:
+
+
+Config References
+-----------------
+
+.. literalinclude:: ../../fastreid/config/defaults.py
+   :language: python
+   :linenos:
+   :lines: 4-
```
```diff
@@ -0,0 +1,96 @@
+fastreid.data
+=======================
+
+.. automodule:: fastreid.data.build
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.data.data\_utils module
+---------------------------------------
+
+.. automodule:: fastreid.data.data_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.data.datasets module
+---------------------------------------
+
+.. automodule:: fastreid.data.datasets.market1501
+   :members:
+
+.. automodule:: fastreid.data.datasets.cuhk03
+   :members:
+
+.. automodule:: fastreid.data.datasets.dukemtmcreid
+   :members:
+
+.. automodule:: fastreid.data.datasets.msmt17
+   :members:
+
+.. automodule:: fastreid.data.datasets.AirportALERT
+   :members:
+
+.. automodule:: fastreid.data.datasets.iLIDS
+   :members:
+
+.. automodule:: fastreid.data.datasets.pku
+   :members:
+
+.. automodule:: fastreid.data.datasets.prai
+   :members:
+
+.. automodule:: fastreid.data.datasets.saivt
+   :members:
+
+.. automodule:: fastreid.data.datasets.sensereid
+   :members:
+
+.. automodule:: fastreid.data.datasets.sysu_mm
+   :members:
+
+.. automodule:: fastreid.data.datasets.thermalworld
+   :members:
+
+.. automodule:: fastreid.data.datasets.pes3d
+   :members:
+
+.. automodule:: fastreid.data.datasets.caviara
+   :members:
+
+.. automodule:: fastreid.data.datasets.viper
+   :members:
+
+.. automodule:: fastreid.data.datasets.lpw
+   :members:
+
+.. automodule:: fastreid.data.datasets.shinpuhkan
+   :members:
+
+.. automodule:: fastreid.data.datasets.wildtracker
+   :members:
+
+.. automodule:: fastreid.data.datasets.cuhk_sysu
+   :members:
+
+
+fastreid.data.samplers module
+---------------------------------------
+
+.. automodule:: fastreid.data.samplers
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.data.transforms module
+---------------------------------------
+
+.. automodule:: fastreid.data.transforms
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :imported-members:
```
```diff
@@ -0,0 +1,9 @@
+fastreid.data.transforms
+====================================
+
+
+.. automodule:: fastreid.data.transforms
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :imported-members:
```
```diff
@@ -0,0 +1,24 @@
+fastreid.engine
+=========================
+
+.. automodule:: fastreid.engine
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.engine.defaults module
+---------------------------------
+
+.. automodule:: fastreid.engine.defaults
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fastreid.engine.hooks module
+---------------------------------
+
+.. automodule:: fastreid.engine.hooks
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -0,0 +1,7 @@
+fastreid.evaluation
+=============================
+
+.. automodule:: fastreid.evaluation
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -0,0 +1,17 @@
+API Documentation
+==================
+
+.. toctree::
+
+   checkpoint
+   config
+   data
+   data_transforms
+   engine
+   evaluation
+   layers
+   model_zoo
+   modeling
+   solver
+   utils
+   export
```
```diff
@@ -0,0 +1,7 @@
+fastreid.layers
+=========================
+
+.. automodule:: fastreid.layers
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -0,0 +1,24 @@
+fastreid.modeling
+===========================
+
+.. automodule:: fastreid.modeling
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Model Registries
+-----------------
+
+These are different registries provided in modeling.
+Each registry provides you the ability to replace it with your customized component,
+without having to modify fastreid's code.
+
+Note that it is impossible to allow users to customize any line of code directly.
+Even just to add one line at some place,
+you'll likely need to find out the smallest registry which contains that line,
+and register your component to that registry.
+
+
+.. autodata:: fastreid.modeling.BACKBONE_REGISTRY
+.. autodata:: fastreid.modeling.META_ARCH_REGISTRY
+.. autodata:: fastreid.modeling.REID_HEADS_REGISTRY
```
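To make the registry note concrete: fastreid's registries follow the detectron2 pattern, where a factory function is registered by decorator and then selected by name in the config (compare `BACKBONE.NAME: build_vit_backbone` in the ViT config above). A sketch, with a hypothetical toy backbone body:

```python
import torch.nn as nn
from fastreid.modeling import BACKBONE_REGISTRY

@BACKBONE_REGISTRY.register()
def build_toy_backbone(cfg):
    # Hypothetical stand-in; a real backbone returns the nn.Module
    # whose output feature map the heads expect.
    return nn.Sequential(nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.ReLU())

# A config would then select it with MODEL.BACKBONE.NAME: build_toy_backbone.
```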
```diff
@@ -0,0 +1,7 @@
+fastreid.solver
+=========================
+
+.. automodule:: fastreid.solver
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -0,0 +1,80 @@
+fastreid.utils
+========================
+
+fastreid.utils.colormap module
+--------------------------------
+
+.. automodule:: fastreid.utils.colormap
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fastreid.utils.comm module
+----------------------------
+
+.. automodule:: fastreid.utils.comm
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.utils.events module
+------------------------------
+
+.. automodule:: fastreid.utils.events
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.utils.logger module
+------------------------------
+
+.. automodule:: fastreid.utils.logger
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.utils.registry module
+--------------------------------
+
+.. automodule:: fastreid.utils.registry
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fastreid.utils.memory module
+----------------------------------
+
+.. automodule:: fastreid.utils.memory
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.utils.analysis module
+----------------------------------
+
+.. automodule:: fastreid.utils.analysis
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+fastreid.utils.visualizer module
+----------------------------------
+
+.. automodule:: fastreid.utils.visualizer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fastreid.utils.video\_visualizer module
+-----------------------------------------
+
+.. automodule:: fastreid.utils.video_visualizer
+   :members:
+   :undoc-members:
+   :show-inheritance:
```
```diff
@@ -17,4 +17,4 @@ termcolor
 scikit-learn
 tabulate
 gdown
-faiss-cpu
+faiss-gpu
```
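The requirements change swaps `faiss-cpu` for `faiss-gpu`. faiss is used for nearest-neighbor search over embedding vectors; a minimal sketch of the kind of inner-product search involved (toy data, not fastreid's evaluation code):

```python
import faiss
import numpy as np

d = 512                                  # embedding dimension (toy value)
gallery = np.random.rand(1000, d).astype("float32")
query = np.random.rand(5, d).astype("float32")
faiss.normalize_L2(gallery)              # unit vectors make IP == cosine
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(d)             # exact inner-product index
index.add(gallery)
scores, ids = index.search(query, 10)    # top-10 gallery matches per query
```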
```diff
@@ -5,4 +5,4 @@
 """
 
 
-__version__ = "0.2.0"
+__version__ = "1.3"
```
```diff
@@ -4,5 +4,12 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from .config import CfgNode, get_cfg
 from .defaults import _C as cfg
+from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
+
+__all__ = [
+    'CfgNode',
+    'get_cfg',
+    'global_cfg',
+    'set_global_cfg',
+    'configurable'
+]
```
```diff
@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import functools
+import inspect
 import logging
 import os
 from typing import Any
```
```diff
@@ -148,6 +150,9 @@ class CfgNode(_CfgNode):
         super().__setattr__(name, val)
 
 
+global_cfg = CfgNode()
+
+
 def get_cfg() -> CfgNode:
     """
     Get a copy of the default config.
```
|
|||
from .defaults import _C
|
||||
|
||||
return _C.clone()
|
||||
|
||||
|
||||
def set_global_cfg(cfg: CfgNode) -> None:
|
||||
"""
|
||||
Let the global config point to the given cfg.
|
||||
Assume that the given "cfg" has the key "KEY", after calling
|
||||
`set_global_cfg(cfg)`, the key can be accessed by:
|
||||
::
|
||||
from detectron2.config import global_cfg
|
||||
print(global_cfg.KEY)
|
||||
By using a hacky global config, you can access these configs anywhere,
|
||||
without having to pass the config object or the values deep into the code.
|
||||
This is a hacky feature introduced for quick prototyping / research exploration.
|
||||
"""
|
||||
global global_cfg
|
||||
global_cfg.clear()
|
||||
global_cfg.update(cfg)
|
||||
|
||||
|
||||
def configurable(init_func=None, *, from_config=None):
|
||||
"""
|
||||
Decorate a function or a class's __init__ method so that it can be called
|
||||
with a :class:`CfgNode` object using a :func:`from_config` function that translates
|
||||
:class:`CfgNode` to arguments.
|
||||
Examples:
|
||||
::
|
||||
# Usage 1: Decorator on __init__:
|
||||
class A:
|
||||
@configurable
|
||||
def __init__(self, a, b=2, c=3):
|
||||
pass
|
||||
@classmethod
|
||||
def from_config(cls, cfg): # 'cfg' must be the first argument
|
||||
# Returns kwargs to be passed to __init__
|
||||
return {"a": cfg.A, "b": cfg.B}
|
||||
a1 = A(a=1, b=2) # regular construction
|
||||
a2 = A(cfg) # construct with a cfg
|
||||
a3 = A(cfg, b=3, c=4) # construct with extra overwrite
|
||||
# Usage 2: Decorator on any function. Needs an extra from_config argument:
|
||||
@configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B})
|
||||
def a_func(a, b=2, c=3):
|
||||
pass
|
||||
a1 = a_func(a=1, b=2) # regular call
|
||||
a2 = a_func(cfg) # call with a cfg
|
||||
a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
|
||||
Args:
|
||||
init_func (callable): a class's ``__init__`` method in usage 1. The
|
||||
class must have a ``from_config`` classmethod which takes `cfg` as
|
||||
the first argument.
|
||||
from_config (callable): the from_config function in usage 2. It must take `cfg`
|
||||
as its first argument.
|
||||
"""
|
||||
|
||||
def check_docstring(func):
|
||||
if func.__module__.startswith("fastreid."):
|
||||
assert (
|
||||
func.__doc__ is not None and "experimental" in func.__doc__.lower()
|
||||
), f"configurable {func} should be marked experimental"
|
||||
|
||||
if init_func is not None:
|
||||
assert (
|
||||
inspect.isfunction(init_func)
|
||||
and from_config is None
|
||||
and init_func.__name__ == "__init__"
|
||||
), "Incorrect use of @configurable. Check API documentation for examples."
|
||||
check_docstring(init_func)
|
||||
|
||||
@functools.wraps(init_func)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
try:
|
||||
from_config_func = type(self).from_config
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
"Class with @configurable must have a 'from_config' classmethod."
|
||||
) from e
|
||||
if not inspect.ismethod(from_config_func):
|
||||
raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
|
||||
|
||||
if _called_with_cfg(*args, **kwargs):
|
||||
explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
|
||||
init_func(self, **explicit_args)
|
||||
else:
|
||||
init_func(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
else:
|
||||
if from_config is None:
|
||||
return configurable # @configurable() is made equivalent to @configurable
|
||||
assert inspect.isfunction(
|
||||
from_config
|
||||
), "from_config argument of configurable must be a function!"
|
||||
|
||||
def wrapper(orig_func):
|
||||
check_docstring(orig_func)
|
||||
|
||||
@functools.wraps(orig_func)
|
||||
def wrapped(*args, **kwargs):
|
||||
if _called_with_cfg(*args, **kwargs):
|
||||
explicit_args = _get_args_from_config(from_config, *args, **kwargs)
|
||||
return orig_func(**explicit_args)
|
||||
else:
|
||||
return orig_func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_args_from_config(from_config_func, *args, **kwargs):
|
||||
"""
|
||||
Use `from_config` to obtain explicit arguments.
|
||||
Returns:
|
||||
dict: arguments to be used for cls.__init__
|
||||
"""
|
||||
signature = inspect.signature(from_config_func)
|
||||
if list(signature.parameters.keys())[0] != "cfg":
|
||||
if inspect.isfunction(from_config_func):
|
||||
name = from_config_func.__name__
|
||||
else:
|
||||
name = f"{from_config_func.__self__}.from_config"
|
||||
raise TypeError(f"{name} must take 'cfg' as the first argument!")
|
||||
support_var_arg = any(
|
||||
param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
|
||||
for param in signature.parameters.values()
|
||||
)
|
||||
if support_var_arg: # forward all arguments to from_config, if from_config accepts them
|
||||
ret = from_config_func(*args, **kwargs)
|
||||
else:
|
||||
# forward supported arguments to from_config
|
||||
supported_arg_names = set(signature.parameters.keys())
|
||||
extra_kwargs = {}
|
||||
for name in list(kwargs.keys()):
|
||||
if name not in supported_arg_names:
|
||||
extra_kwargs[name] = kwargs.pop(name)
|
||||
ret = from_config_func(*args, **kwargs)
|
||||
# forward the other arguments to __init__
|
||||
ret.update(extra_kwargs)
|
||||
return ret
|
||||
|
||||
|
||||
def _called_with_cfg(*args, **kwargs):
|
||||
"""
|
||||
Returns:
|
||||
bool: whether the arguments contain CfgNode and should be considered
|
||||
forwarded to from_config.
|
||||
"""
|
||||
|
||||
if len(args) and isinstance(args[0], _CfgNode):
|
||||
return True
|
||||
if isinstance(kwargs.pop("cfg", None), _CfgNode):
|
||||
return True
|
||||
# `from_config`'s first argument is forced to be "cfg".
|
||||
# So the above check covers all cases.
|
||||
return False
|
||||
|
|
|
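For reference, the decorator above is used as follows. This is an illustrative sketch only: `MyHead` is a made-up class and the two config keys are stand-ins, not a fastreid API.

# Hypothetical sketch of the @configurable pattern introduced above;
# `MyHead` and the two cfg keys are illustrative, not fastreid classes.
from fastreid.config import configurable, get_cfg


class MyHead:
    @configurable
    def __init__(self, in_feat, num_classes=751):
        self.in_feat = in_feat
        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg):
        # 'cfg' must be the first argument; returns kwargs for __init__
        return {
            "in_feat": 2048,                          # would come from cfg in practice
            "num_classes": cfg.MODEL.HEADS.NUM_CLASSES,
        }


cfg = get_cfg()
head_a = MyHead(in_feat=2048, num_classes=1000)  # plain construction
head_b = MyHead(cfg)                             # construction from a CfgNode
head_c = MyHead(cfg, num_classes=100)            # cfg plus explicit overwrite
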
@@ -23,7 +23,7 @@ _C.MODEL = CN()
 _C.MODEL.DEVICE = "cuda"
 _C.MODEL.META_ARCHITECTURE = "Baseline"

-_C.MODEL.FREEZE_LAYERS = ['']
+_C.MODEL.FREEZE_LAYERS = []

 # MoCo memory size
 _C.MODEL.QUEUE_SIZE = 8192

@@ -46,8 +46,14 @@ _C.MODEL.BACKBONE.WITH_IBN = False
 _C.MODEL.BACKBONE.WITH_SE = False
 # If use Non-local block in backbone
 _C.MODEL.BACKBONE.WITH_NL = False
+# Vision Transformer options
+_C.MODEL.BACKBONE.SIE_COE = 3.0
+_C.MODEL.BACKBONE.STRIDE_SIZE = (16, 16)
+_C.MODEL.BACKBONE.DROP_PATH_RATIO = 0.1
+_C.MODEL.BACKBONE.DROP_RATIO = 0.0
+_C.MODEL.BACKBONE.ATT_DROP_RATE = 0.0
 # If use ImageNet pretrain model
-_C.MODEL.BACKBONE.PRETRAIN = True
+_C.MODEL.BACKBONE.PRETRAIN = False
 # Pretrain model path
 _C.MODEL.BACKBONE.PRETRAIN_PATH = ''

@@ -63,18 +69,18 @@ _C.MODEL.HEADS.NUM_CLASSES = 0
 # Embedding dimension in head
 _C.MODEL.HEADS.EMBEDDING_DIM = 0
 # If use BNneck in embedding
-_C.MODEL.HEADS.WITH_BNNECK = True
+_C.MODEL.HEADS.WITH_BNNECK = False
 # Triplet feature using feature before(after) bnneck
 _C.MODEL.HEADS.NECK_FEAT = "before"  # options: before, after
 # Pooling layer type
-_C.MODEL.HEADS.POOL_LAYER = "avgpool"
+_C.MODEL.HEADS.POOL_LAYER = "GlobalAvgPool"

 # Classification layer type
-_C.MODEL.HEADS.CLS_LAYER = "linear"  # "arcSoftmax" or "circleSoftmax"
+_C.MODEL.HEADS.CLS_LAYER = "Linear"  # "ArcSoftmax" or "CircleSoftmax"

 # Margin and Scale for margin-based classification layer
-_C.MODEL.HEADS.MARGIN = 0.15
-_C.MODEL.HEADS.SCALE = 128
+_C.MODEL.HEADS.MARGIN = 0.
+_C.MODEL.HEADS.SCALE = 1

 # ---------------------------------------------------------------------------- #
 # REID LOSSES options

@@ -100,7 +106,7 @@ _C.MODEL.LOSSES.FL.SCALE = 1.0
 _C.MODEL.LOSSES.TRI = CN()
 _C.MODEL.LOSSES.TRI.MARGIN = 0.3
 _C.MODEL.LOSSES.TRI.NORM_FEAT = False
-_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.HARD_MINING = False
 _C.MODEL.LOSSES.TRI.SCALE = 1.0

 # Circle Loss options

@@ -128,8 +134,10 @@ _C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
 # -----------------------------------------------------------------------------

 _C.KD = CN()
-_C.KD.MODEL_CONFIG = ""
-_C.KD.MODEL_WEIGHTS = ""
+_C.KD.MODEL_CONFIG = []
+_C.KD.MODEL_WEIGHTS = []
+_C.KD.EMA = CN({"ENABLED": False})
+_C.KD.EMA.MOMENTUM = 0.999

 # -----------------------------------------------------------------------------
 # INPUT

@@ -140,18 +148,26 @@ _C.INPUT.SIZE_TRAIN = [256, 128]
 # Size of the image during test
 _C.INPUT.SIZE_TEST = [256, 128]

+# `True` if cropping is used for data augmentation during training
+_C.INPUT.CROP = CN({"ENABLED": False})
+# Size of the image cropped
+_C.INPUT.CROP.SIZE = [224, 224]
+# Size of the origin size cropped
+_C.INPUT.CROP.SCALE = [0.16, 1]
+# Aspect ratio of the origin aspect ratio cropped
+_C.INPUT.CROP.RATIO = [3./4., 4./3.]
+
 # Random probability for image horizontal flip
-_C.INPUT.DO_FLIP = True
-_C.INPUT.FLIP_PROB = 0.5
+_C.INPUT.FLIP = CN({"ENABLED": False})
+_C.INPUT.FLIP.PROB = 0.5

 # Value of padding size
-_C.INPUT.DO_PAD = True
-_C.INPUT.PADDING_MODE = 'constant'
-_C.INPUT.PADDING = 10
+_C.INPUT.PADDING = CN({"ENABLED": False})
+_C.INPUT.PADDING.MODE = 'constant'
+_C.INPUT.PADDING.SIZE = 10

 # Random color jitter
-_C.INPUT.CJ = CN()
-_C.INPUT.CJ.ENABLED = False
+_C.INPUT.CJ = CN({"ENABLED": False})
 _C.INPUT.CJ.PROB = 0.5
 _C.INPUT.CJ.BRIGHTNESS = 0.15
 _C.INPUT.CJ.CONTRAST = 0.15

@@ -159,24 +175,22 @@ _C.INPUT.CJ.SATURATION = 0.1
 _C.INPUT.CJ.HUE = 0.1

 # Random Affine
-_C.INPUT.DO_AFFINE = False
+_C.INPUT.AFFINE = CN({"ENABLED": False})

 # Auto augmentation
-_C.INPUT.DO_AUTOAUG = False
-_C.INPUT.AUTOAUG_PROB = 0.0
+_C.INPUT.AUTOAUG = CN({"ENABLED": False})
+_C.INPUT.AUTOAUG.PROB = 0.0

 # Augmix augmentation
-_C.INPUT.DO_AUGMIX = False
-_C.INPUT.AUGMIX_PROB = 0.0
+_C.INPUT.AUGMIX = CN({"ENABLED": False})
+_C.INPUT.AUGMIX.PROB = 0.0

 # Random Erasing
-_C.INPUT.REA = CN()
-_C.INPUT.REA.ENABLED = False
+_C.INPUT.REA = CN({"ENABLED": False})
 _C.INPUT.REA.PROB = 0.5
 _C.INPUT.REA.VALUE = [0.485*255, 0.456*255, 0.406*255]
 # Random Patch
-_C.INPUT.RPT = CN()
-_C.INPUT.RPT.ENABLED = False
+_C.INPUT.RPT = CN({"ENABLED": False})
 _C.INPUT.RPT.PROB = 0.5

 # -----------------------------------------------------------------------------

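Note that this restructuring replaces every flat DO_*/*_PROB key with a nested node carrying an ENABLED switch, so existing configs need updating. A minimal sketch of the renamed keys, set programmatically (the particular values are arbitrary):

# Sketch of the renamed INPUT keys after this change; values are arbitrary.
from fastreid.config import get_cfg

cfg = get_cfg()
cfg.merge_from_list([
    "INPUT.FLIP.ENABLED", "True",     # was INPUT.DO_FLIP / INPUT.FLIP_PROB
    "INPUT.FLIP.PROB", "0.5",
    "INPUT.PADDING.ENABLED", "True",  # was INPUT.DO_PAD / INPUT.PADDING
    "INPUT.PADDING.SIZE", "10",
    "INPUT.REA.ENABLED", "True",      # each aug now hangs off CN({"ENABLED": ...})
])
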
@@ -194,21 +208,22 @@ _C.DATASETS.COMBINEALL = False
 # DataLoader
 # -----------------------------------------------------------------------------
 _C.DATALOADER = CN()
-# P/K Sampler for data loading
-_C.DATALOADER.PK_SAMPLER = True
-# Naive sampler which don't consider balanced identity sampling
-_C.DATALOADER.NAIVE_WAY = True
+# Options: TrainingSampler, NaiveIdentitySampler, BalancedIdentitySampler
+_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
 # Number of instance for each person
 _C.DATALOADER.NUM_INSTANCE = 4
 _C.DATALOADER.NUM_WORKERS = 8

+# For set re-weight
+_C.DATALOADER.SET_WEIGHT = []
+
 # ---------------------------------------------------------------------------- #
 # Solver
 # ---------------------------------------------------------------------------- #
 _C.SOLVER = CN()

 # AUTOMATIC MIXED PRECISION
-_C.SOLVER.FP16_ENABLED = False
+_C.SOLVER.AMP = CN({"ENABLED": False})

 # Optimizer
 _C.SOLVER.OPT = "Adam"

@@ -216,14 +231,25 @@ _C.SOLVER.OPT = "Adam"
 _C.SOLVER.MAX_EPOCH = 120

 _C.SOLVER.BASE_LR = 3e-4
-_C.SOLVER.BIAS_LR_FACTOR = 1.
+
+# This LR is applied to the last classification layer if
+# you want it 10x higher than BASE_LR.
 _C.SOLVER.HEADS_LR_FACTOR = 1.
+
 _C.SOLVER.MOMENTUM = 0.9
-_C.SOLVER.NESTEROV = True
+_C.SOLVER.NESTEROV = False

 _C.SOLVER.WEIGHT_DECAY = 0.0005
-_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
+# The weight decay that's applied to parameters of normalization layers
+# (typically the affine transformation)
+_C.SOLVER.WEIGHT_DECAY_NORM = 0.0005
+
+# The previous detection code used a 2x higher LR and 0 WD for bias.
+# This is not useful (at least for recent models). You should avoid
+# changing these and they exist only to reproduce previous model
+# training if desired.
+_C.SOLVER.BIAS_LR_FACTOR = 1.0
+_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY

 # Multi-step learning rate options
 _C.SOLVER.SCHED = "MultiStepLR"

@@ -238,61 +264,56 @@ _C.SOLVER.ETA_MIN_LR = 1e-7

 # Warmup options
 _C.SOLVER.WARMUP_FACTOR = 0.1
-_C.SOLVER.WARMUP_EPOCHS = 10
+_C.SOLVER.WARMUP_ITERS = 1000
 _C.SOLVER.WARMUP_METHOD = "linear"

 # Backbone freeze iters
 _C.SOLVER.FREEZE_ITERS = 0

 # FC freeze iters
 _C.SOLVER.FREEZE_FC_ITERS = 0


 # SWA options
 # _C.SOLVER.SWA = CN()
 # _C.SOLVER.SWA.ENABLED = False
 # _C.SOLVER.SWA.ITER = 10
 # _C.SOLVER.SWA.PERIOD = 2
 # _C.SOLVER.SWA.LR_FACTOR = 10.
 # _C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
 # _C.SOLVER.SWA.LR_SCHED = False

 _C.SOLVER.CHECKPOINT_PERIOD = 20

 # Number of images per batch across all machines.
-# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
-# see 2 images per batch
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 256, each GPU will
+# see 32 images per batch
 _C.SOLVER.IMS_PER_BATCH = 64

-# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
-# see 2 images per batch
+# Gradient clipping
+_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
+# Type of gradient clipping, currently 2 values are supported:
+# - "value": the absolute values of elements of each gradients are clipped
+# - "norm": the norm of the gradient for each parameter is clipped thus
+#   affecting all elements in the parameter
+_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
+# Maximum absolute value used for clipping gradients
+_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 5.0
+# Floating point number p for L-p norm to be used with the "norm"
+# gradient clipping type; for L-inf, please specify .inf
+_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0

 _C.TEST = CN()

 _C.TEST.EVAL_PERIOD = 20

-# Number of images per batch in one process.
+# Number of images per batch across all machines.
 _C.TEST.IMS_PER_BATCH = 64
 _C.TEST.METRIC = "cosine"
-_C.TEST.ROC_ENABLED = False
-_C.TEST.FLIP_ENABLED = False
+_C.TEST.ROC = CN({"ENABLED": False})
+_C.TEST.FLIP = CN({"ENABLED": False})

 # Average query expansion
-_C.TEST.AQE = CN()
-_C.TEST.AQE.ENABLED = False
+_C.TEST.AQE = CN({"ENABLED": False})
 _C.TEST.AQE.ALPHA = 3.0
 _C.TEST.AQE.QE_TIME = 1
 _C.TEST.AQE.QE_K = 5

 # Re-rank
-_C.TEST.RERANK = CN()
-_C.TEST.RERANK.ENABLED = False
+_C.TEST.RERANK = CN({"ENABLED": False})
 _C.TEST.RERANK.K1 = 20
 _C.TEST.RERANK.K2 = 6
 _C.TEST.RERANK.LAMBDA = 0.3

 # Precise batchnorm
-_C.TEST.PRECISE_BN = CN()
-_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN = CN({"ENABLED": False})
 _C.TEST.PRECISE_BN.DATASET = 'Market1501'
 _C.TEST.PRECISE_BN.NUM_ITER = 300

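For orientation, the new CLIP_GRADIENTS options typically map onto the standard torch utilities roughly as sketched below. This helper is illustrative only; fastreid's trainer applies these options internally.

# Illustrative helper only, not the trainer's actual code.
import torch


def clip_gradients(model, cfg):
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return
    if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "norm":
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE,
            norm_type=cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE,
        )
    else:  # "value": clip each gradient element independently
        torch.nn.utils.clip_grad_value_(
            model.parameters(), clip_value=cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        )
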
@@ -4,4 +4,14 @@
 @contact: sherlockliao01@gmail.com
 """

-from .build import build_reid_train_loader, build_reid_test_loader
+from . import transforms  # isort:skip
+from .build import (
+    build_reid_train_loader,
+    build_reid_test_loader
+)
+from .common import CommDataset
+
+# ensure the builtin datasets are registered
+from . import datasets, samplers  # isort:skip
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]

@@ -4,87 +4,165 @@
 @contact: sherlockliao01@gmail.com
 """

 import logging
 import os

 import torch
-from torch._six import container_abcs, string_classes, int_classes
-from torch.utils.data import DataLoader
+
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
+    from torch._six import string_classes
+else:
+    string_classes = str
+
+from collections import Mapping
+
+from fastreid.config import configurable
 from fastreid.utils import comm
 from . import samplers
 from .common import CommDataset
+from .data_utils import DataLoaderX
 from .datasets import DATASET_REGISTRY
 from .transforms import build_transforms

 __all__ = [
     "build_reid_train_loader",
     "build_reid_test_loader"
 ]

 _root = os.getenv("FASTREID_DATASETS", "datasets")


-def build_reid_train_loader(cfg, mapper=None, **kwargs):
-    cfg = cfg.clone()
-
-    train_items = list()
-    for d in cfg.DATASETS.NAMES:
-        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL, **kwargs)
-        if comm.is_main_process():
-            dataset.show_train()
-        train_items.extend(dataset.train)
-
-    if mapper is not None:
-        transforms = mapper
-    else:
+def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=None, **kwargs):
+    if transforms is None:
         transforms = build_transforms(cfg, is_train=True)

-    train_set = CommDataset(train_items, transforms, relabel=True)
+    if train_set is None:
+        train_items = list()
+        for d in cfg.DATASETS.NAMES:
+            data = DATASET_REGISTRY.get(d)(root=_root, **kwargs)
+            if comm.is_main_process():
+                data.show_train()
+            train_items.extend(data.train)

-    num_workers = cfg.DATALOADER.NUM_WORKERS
-    num_instance = cfg.DATALOADER.NUM_INSTANCE
-    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
+        train_set = CommDataset(train_items, transforms, relabel=True)

-    if cfg.DATALOADER.PK_SAMPLER:
-        if cfg.DATALOADER.NAIVE_WAY:
-            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
-        else:
-            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
-    else:
-        data_sampler = samplers.TrainingSampler(len(train_set))
-    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
+    if sampler is None:
+        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+        num_instance = cfg.DATALOADER.NUM_INSTANCE
+        mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
+
+        logger = logging.getLogger(__name__)
+        logger.info("Using training sampler {}".format(sampler_name))
+        if sampler_name == "TrainingSampler":
+            sampler = samplers.TrainingSampler(len(train_set))
+        elif sampler_name == "NaiveIdentitySampler":
+            sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
+        elif sampler_name == "BalancedIdentitySampler":
+            sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
+        elif sampler_name == "SetReWeightSampler":
+            set_weight = cfg.DATALOADER.SET_WEIGHT
+            sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
+        elif sampler_name == "ImbalancedDatasetSampler":
+            sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
+        else:
+            raise ValueError("Unknown training sampler: {}".format(sampler_name))

-    train_loader = torch.utils.data.DataLoader(
-        train_set,
+    return {
+        "train_set": train_set,
+        "sampler": sampler,
+        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
+        "num_workers": cfg.DATALOADER.NUM_WORKERS,
+    }
+
+
+@configurable(from_config=_train_loader_from_config)
+def build_reid_train_loader(
+        train_set, *, sampler=None, total_batch_size, num_workers=0,
+):
+    """
+    Build a dataloader for object re-identification with some default features.
+    This interface is experimental.
+
+    Returns:
+        torch.utils.data.DataLoader: a dataloader.
+    """
+
+    mini_batch_size = total_batch_size // comm.get_world_size()
+
+    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
+
+    train_loader = DataLoaderX(
+        comm.get_local_rank(),
+        dataset=train_set,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
         collate_fn=fast_batch_collator,
         pin_memory=True,
     )

     return train_loader


-def build_reid_test_loader(cfg, dataset_name, mapper=None, **kwargs):
-    cfg = cfg.clone()
-
-    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
-    if comm.is_main_process():
-        dataset.show_test()
-    test_items = dataset.query + dataset.gallery
-
-    if mapper is not None:
-        transforms = mapper
-    else:
+def _test_loader_from_config(cfg, *, dataset_name=None, test_set=None, num_query=0, transforms=None, **kwargs):
+    if transforms is None:
         transforms = build_transforms(cfg, is_train=False)

-    test_set = CommDataset(test_items, transforms, relabel=False)
+    if test_set is None:
+        assert dataset_name is not None, "dataset_name must be explicitly passed in when test_set is not provided"
+        data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
+        if comm.is_main_process():
+            data.show_test()
+        test_items = data.query + data.gallery
+        test_set = CommDataset(test_items, transforms, relabel=False)

-    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
+        # Update query number
+        num_query = len(data.query)
+
+    return {
+        "test_set": test_set,
+        "test_batch_size": cfg.TEST.IMS_PER_BATCH,
+        "num_query": num_query,
+    }
+
+
+@configurable(from_config=_test_loader_from_config)
+def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
+    """
+    Similar to `build_reid_train_loader`. This sampler coordinates all workers to produce
+    the exact set of all samples.
+    This interface is experimental.
+
+    Args:
+        test_set:
+        test_batch_size:
+        num_query:
+        num_workers:
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given reid dataset, with
+        the test-time transformation.
+
+    Examples:
+    ::
+        data_loader = build_reid_test_loader(test_set, test_batch_size, num_query)
+        # or, instantiate with a CfgNode:
+        data_loader = build_reid_test_loader(cfg, "my_test")
+    """
+
+    mini_batch_size = test_batch_size // comm.get_world_size()
     data_sampler = samplers.InferenceSampler(len(test_set))
     batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
-    test_loader = DataLoader(
-        test_set,
+    test_loader = DataLoaderX(
+        comm.get_local_rank(),
+        dataset=test_set,
         batch_sampler=batch_sampler,
-        num_workers=4,  # save some memory
+        num_workers=num_workers,  # save some memory
         collate_fn=fast_batch_collator,
         pin_memory=True,
     )
-    return test_loader, len(dataset.query)
+    return test_loader, num_query


 def trivial_batch_collator(batch):

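Both call styles enabled by the @configurable rewrite are sketched below; "Market1501" is just one registered dataset name, and the explicit arguments in the second variant are placeholders.

# Sketch of the two supported call styles (the dataset must exist on disk).
from fastreid.config import get_cfg
from fastreid.data import build_reid_train_loader

cfg = get_cfg()
cfg.DATASETS.NAMES = ("Market1501",)

# 1) config-driven: dataset, transforms and sampler are all built from cfg
train_loader = build_reid_train_loader(cfg)

# 2) explicit: bypass cfg with prebuilt objects (placeholder names, commented)
# train_loader = build_reid_train_loader(
#     train_set=my_train_set, sampler=my_sampler,
#     total_batch_size=64, num_workers=4,
# )
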
@@ -105,12 +183,12 @@ def fast_batch_collator(batched_inputs):
             out[i] += tensor
         return out

-    elif isinstance(elem, container_abcs.Mapping):
+    elif isinstance(elem, Mapping):
         return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}

     elif isinstance(elem, float):
         return torch.tensor(batched_inputs, dtype=torch.float64)
-    elif isinstance(elem, int_classes):
+    elif isinstance(elem, int)::
        return torch.tensor(batched_inputs)
     elif isinstance(elem, string_classes):
         return batched_inputs

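To make the Mapping branch concrete: given a list of per-sample dicts (as CommDataset yields), the collator recurses per key. A small sketch, assuming fast_batch_collator is in scope:

# Sketch: fast_batch_collator recursing into dict-typed samples.
import torch

batched_inputs = [
    {"images": torch.zeros(3, 256, 128), "targets": 0},
    {"images": torch.ones(3, 256, 128), "targets": 5},
]
batch = fast_batch_collator(batched_inputs)
# batch["images"].shape -> torch.Size([2, 3, 256, 128])
# batch["targets"]      -> tensor([0, 5])
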
@@ -3,8 +3,13 @@
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
+import torch
 import numpy as np
 from PIL import Image, ImageOps
+import threading
+
+import queue
+from torch.utils.data import DataLoader

 from fastreid.utils.file_io import PathManager


@@ -13,6 +18,7 @@ def read_image(file_name, format=None):
     """
     Read an image into the given format.
+    Will apply rotation and flipping if the image has such exif information.

     Args:
         file_name (str): image file path
         format (str): one of the supported image modes in PIL, or "BGR"


@@ -52,3 +58,145 @@ def read_image(file_name, format=None):
        image = Image.fromarray(image)

    return image


"""
# based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
This is a single-function package that transforms an arbitrary generator into a background-thread generator that
prefetches several batches of data in a parallel background thread.

This is useful if you have a computationally heavy process (CPU or GPU) that
iteratively processes minibatches from the generator while the generator
consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).

By default these two processes will constantly wait for one another to finish. If you make the generator work in
prefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time.
We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.

Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
This package contains this object
 - BackgroundGenerator(any_other_generator[, max_prefetch=something])
"""


class BackgroundGenerator(threading.Thread):
    """
    The usage is below:
    >> for batch in BackgroundGenerator(my_minibatch_iterator):
    >>     doit()
    More details are written in the BackgroundGenerator doc:
    >> help(BackgroundGenerator)
    """

    def __init__(self, generator, local_rank, max_prefetch=10):
        """
        This function transforms a generator into a background-thread generator.
        :param generator: generator or genexp or any
        It can be used with any minibatch generator.

        It is quite lightweight, but not entirely weightless.
        Using global variables inside the generator is not recommended (it may raise the GIL and zero-out the
        benefit of having a background thread).
        The ideal use case is when everything it requires is stored inside it and everything it
        outputs is passed through the queue.

        There's no restriction on doing weird stuff, reading/writing files, retrieving
        URLs [or whatever] whilst iterating.

        :param max_prefetch: defines how many iterations (at most) the background generator can keep
        stored at any moment of time.
        Whenever there are already max_prefetch batches stored in the queue, the background process will halt until
        one of these batches is dequeued.

        !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator!

        Setting max_prefetch to -1 lets it store as many batches as it can, which will work
        slightly (if any) faster, but will require storing
        all batches in memory. If you use an infinite generator with max_prefetch=-1, it will exceed the RAM size
        unless dequeued quickly enough.
        """
        super().__init__()
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.exit_event = threading.Event()
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                break
            self.queue.put(item)
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        self.stream = torch.cuda.Stream(
            local_rank
        )  # create a new cuda stream in each process
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def _shutdown_background_thread(self):
        if not self.iter.is_alive():
            # avoid re-entrance or ill-conditioned thread state
            return

        # Set exit event to True for background threading stopping
        self.iter.exit_event.set()

        # Exhaust all remaining elements, so that the queue becomes empty,
        # and the thread should quit
        for _ in self.iter:
            pass

        # Waiting for background thread to quit
        self.iter.join()

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in self.batch:
                if isinstance(self.batch[k], torch.Tensor):
                    self.batch[k] = self.batch[k].to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        torch.cuda.current_stream().wait_stream(
            self.stream
        )  # wait for tensors to be put on GPU
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

    # Signal for shutting down background thread
    def shutdown(self):
        # If the dataloader is to be freed, shutdown its BackgroundGenerator
        self._shutdown_background_thread()

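A hedged usage sketch for the two classes above. DataLoaderX expects dict-style batches (as fastreid's CommDataset produces) and an available CUDA device; ToyDictDataset is a made-up stand-in.

# Made-up toy dataset standing in for CommDataset; requires a CUDA device.
import torch
from torch.utils.data import Dataset


class ToyDictDataset(Dataset):
    def __len__(self):
        return 128

    def __getitem__(self, idx):
        return {"images": torch.randn(3, 256, 128), "targets": idx % 10}


loader = DataLoaderX(local_rank=0, dataset=ToyDictDataset(), batch_size=32)
for batch in loader:                 # batches are prefetched on a side thread
    images = batch["images"]         # already moved to cuda:0 via the side stream
loader.shutdown()                    # stop the background thread explicitly
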
@@ -14,6 +14,9 @@ __all__ = ['AirportALERT', ]

 @DATASET_REGISTRY.register()
 class AirportALERT(ImageDataset):
+    """Airport
+    """
     dataset_dir = "AirportALERT"
     dataset_name = "airport"

@@ -21,6 +21,8 @@ from .AirportALERT import AirportALERT
 from .iLIDS import iLIDS
 from .pku import PKU
 from .prai import PRAI
+from .prid import PRID
+from .grid import GRID
 from .saivt import SAIVT
 from .sensereid import SenseReID
 from .sysu_mm import SYSU_mm

@@ -38,5 +40,4 @@ from .veri import VeRi
 from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
 from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild

-
 __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]

@@ -7,6 +7,7 @@
 import copy
 import logging
 import os

+from tabulate import tabulate
 from termcolor import colored


@@ -16,10 +17,11 @@ logger = logging.getLogger(__name__)
 class Dataset(object):
     """An abstract class representing a Dataset.
     This is the base class for ``ImageDataset`` and ``VideoDataset``.

     Args:
-        train (list): contains tuples of (img_path(s), pid, camid).
-        query (list): contains tuples of (img_path(s), pid, camid).
-        gallery (list): contains tuples of (img_path(s), pid, camid).
+        train (list or Callable): contains tuples of (img_path(s), pid, camid).
+        query (list or Callable): contains tuples of (img_path(s), pid, camid).
+        gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
         transform: transform function.
         mode (str): 'train', 'query' or 'gallery'.
         combineall (bool): combines train, query and gallery in a

@@ -30,17 +32,14 @@ class Dataset(object):

     def __init__(self, train, query, gallery, transform=None, mode='train',
                  combineall=False, verbose=True, **kwargs):
-        self.train = train
-        self.query = query
-        self.gallery = gallery
+        self._train = train
+        self._query = query
+        self._gallery = gallery
         self.transform = transform
         self.mode = mode
         self.combineall = combineall
         self.verbose = verbose

-        self.num_train_pids = self.get_num_pids(self.train)
-        self.num_train_cams = self.get_num_cams(self.train)
-
         if self.combineall:
             self.combine_all()

@@ -54,6 +53,24 @@ class Dataset(object):
             raise ValueError('Invalid mode. Got {}, but expected to be '
                              'one of [train | query | gallery]'.format(self.mode))

+    @property
+    def train(self):
+        if callable(self._train):
+            self._train = self._train()
+        return self._train
+
+    @property
+    def query(self):
+        if callable(self._query):
+            self._query = self._query()
+        return self._query
+
+    @property
+    def gallery(self):
+        if callable(self._gallery):
+            self._gallery = self._gallery()
+        return self._gallery
+
     def __getitem__(self, index):
         raise NotImplementedError

@@ -100,15 +117,14 @@ class Dataset(object):
             for img_path, pid, camid in data:
                 if pid in self._junk_pids:
                     continue
-                pid = self.dataset_name + "_" + str(pid)
-                camid = self.dataset_name + "_" + str(camid)
+                pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
+                camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
                 combined.append((img_path, pid, camid))

         _combine_data(self.query)
         _combine_data(self.gallery)

-        self.train = combined
-        self.num_train_pids = self.get_num_pids(self.train)
+        self._train = combined

     def check_before_run(self, required_files):
         """Checks if required files exist before going deeper.

@@ -132,9 +148,6 @@ class ImageDataset(Dataset):
     data in each batch has shape (batch_size, channel, height, width).
     """

-    def __init__(self, train, query, gallery, **kwargs):
-        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
-
     def show_train(self):
         num_train_pids, num_train_cams = self.parse_data(self.train)

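The property change above implements lazy materialization: a split may now be passed as a zero-argument callable that is only invoked, then cached, on first access. The pattern in isolation, with a toy class rather than the fastreid Dataset:

# The lazy-split pattern in isolation (toy class, not the fastreid Dataset).
class LazySplits:
    def __init__(self, train):
        self._train = train               # a list, or a zero-arg callable

    @property
    def train(self):
        if callable(self._train):
            self._train = self._train()   # materialize once, cache the result
        return self._train


ds = LazySplits(train=lambda: [("img.jpg", "pid_0", "cam_0")])
print(ds.train)   # the lambda runs here, on first access
print(ds.train)   # cached list; the lambda is not re-run
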
@@ -5,20 +5,18 @@
 """

 import os
-from scipy.io import loadmat
 from glob import glob

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
-import pdb
-import random
-import numpy as np

-__all__ = ['CAVIARa',]
+__all__ = ['CAVIARa', ]


 @DATASET_REGISTRY.register()
 class CAVIARa(ImageDataset):
+    """CAVIARa
+    """
     dataset_dir = "CAVIARa"
     dataset_name = "caviara"

@@ -100,7 +100,7 @@ class CUHK03(ImageDataset):

         import h5py
         from imageio import imwrite
-        from scipy.io import loadmat
+        from scipy import io

         PathManager.mkdirs(self.imgs_detected_dir)
         PathManager.mkdirs(self.imgs_labeled_dir)

@@ -236,7 +236,7 @@ class CUHK03(ImageDataset):

         print('Creating new split for detected images (767/700) ...')
         train_info, query_info, gallery_info = _extract_new_split(
-            loadmat(self.split_new_det_mat_path),
+            io.loadmat(self.split_new_det_mat_path),
             self.imgs_detected_dir
         )
         split = [{

@@ -256,7 +256,7 @@ class CUHK03(ImageDataset):

         print('Creating new split for labeled images (767/700) ...')
         train_info, query_info, gallery_info = _extract_new_split(
-            loadmat(self.split_new_lab_mat_path),
+            io.loadmat(self.split_new_lab_mat_path),
             self.imgs_labeled_dir
         )
         split = [{

@@ -15,7 +15,7 @@ from ..datasets import DATASET_REGISTRY

 @DATASET_REGISTRY.register()
 class cuhkSYSU(ImageDataset):
-    r"""CUHK SYSU datasets.
+    """CUHK SYSU datasets.

     The dataset is collected from two sources: street snap and movie.
     In street snap, 12,490 images and 6,057 query persons were collected

@@ -0,0 +1,44 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os
from glob import glob

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset

__all__ = ['GRID', ]


@DATASET_REGISTRY.register()
class GRID(ImageDataset):
    """GRID
    """
    dataset_dir = "underground_reid"
    dataset_name = 'grid'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir, 'images')

        required_files = [self.train_path]
        self.check_before_run(required_files)

        train = self.process_train(self.train_path)

        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        data = []
        img_paths = glob(os.path.join(train_path, "*.jpeg"))

        for img_path in img_paths:
            img_name = os.path.basename(img_path)
            img_info = img_name.split('_')
            pid = self.dataset_name + "_" + img_info[0]
            camid = self.dataset_name + "_" + img_info[1]
            data.append([img_path, pid, camid])
        return data

@@ -15,6 +15,8 @@ __all__ = ['iLIDS', ]

 @DATASET_REGISTRY.register()
 class iLIDS(ImageDataset):
+    """iLIDS
+    """
     dataset_dir = "iLIDS"
     dataset_name = "ilids"

@@ -15,7 +15,9 @@ __all__ = ['LPW', ]

 @DATASET_REGISTRY.register()
 class LPW(ImageDataset):
-    dataset_dir = "pep_256x128"
+    """LPW
+    """
+    dataset_dir = "pep_256x128/data_slim"
     dataset_name = "lpw"

     def __init__(self, root='datasets', **kwargs):

@@ -62,11 +62,10 @@ class Market1501(ImageDataset):
             required_files.append(self.extra_gallery_dir)
         self.check_before_run(required_files)

-        train = self.process_dir(self.train_dir)
-        query = self.process_dir(self.query_dir, is_train=False)
-        gallery = self.process_dir(self.gallery_dir, is_train=False)
-        if self.market1501_500k:
-            gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
+        train = lambda: self.process_dir(self.train_dir)
+        query = lambda: self.process_dir(self.query_dir, is_train=False)
+        gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
+                          (self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])

         super(Market1501, self).__init__(train, query, gallery, **kwargs)

@@ -15,6 +15,8 @@ __all__ = ['PeS3D',]

 @DATASET_REGISTRY.register()
 class PeS3D(ImageDataset):
+    """3Dpes
+    """
     dataset_dir = "3DPeS"
     dataset_name = "pes3d"

@@ -15,6 +15,8 @@ __all__ = ['PKU', ]

 @DATASET_REGISTRY.register()
 class PKU(ImageDataset):
+    """PKU
+    """
     dataset_dir = "PKUv1a_128x48"
     dataset_name = 'pku'

@@ -5,18 +5,18 @@
 """

 import os
-from scipy.io import loadmat
 from glob import glob

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
-import pdb

-__all__ = ['PRAI',]
+__all__ = ['PRAI', ]


 @DATASET_REGISTRY.register()
 class PRAI(ImageDataset):
+    """PRAI
+    """
     dataset_dir = "PRAI-1581"
     dataset_name = 'prai'

@@ -41,4 +41,3 @@ class PRAI(ImageDataset):
             camid = self.dataset_name + "_" + img_info[1]
             data.append([img_path, pid, camid])
         return data
-

@@ -0,0 +1,41 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset

__all__ = ['PRID', ]


@DATASET_REGISTRY.register()
class PRID(ImageDataset):
    """PRID
    """
    dataset_dir = "prid_2011"
    dataset_name = 'prid'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir, 'slim_train')

        required_files = [self.train_path]
        self.check_before_run(required_files)

        train = self.process_train(self.train_path)

        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        data = []
        for root, dirs, files in os.walk(train_path):
            for img_name in filter(lambda x: x.endswith('.png'), files):
                img_path = os.path.join(root, img_name)
                pid = self.dataset_name + '_' + root.split('/')[-1].split('_')[1]
                camid = self.dataset_name + '_' + img_name.split('_')[0]
                data.append([img_path, pid, camid])
        return data

@@ -15,6 +15,8 @@ __all__ = ['SAIVT', ]

 @DATASET_REGISTRY.register()
 class SAIVT(ImageDataset):
+    """SAIVT
+    """
     dataset_dir = "SAIVT-SoftBio"
     dataset_name = "saivt"

@@ -15,6 +15,8 @@ __all__ = ['SenseReID', ]

 @DATASET_REGISTRY.register()
 class SenseReID(ImageDataset):
+    """Sense reid
+    """
     dataset_dir = "SenseReID"
     dataset_name = "senseid"

@@ -14,6 +14,8 @@ __all__ = ['Shinpuhkan', ]

 @DATASET_REGISTRY.register()
 class Shinpuhkan(ImageDataset):
+    """shinpuhkan
+    """
     dataset_dir = "shinpuhkan"
     dataset_name = 'shinpuhkan'

@@ -5,18 +5,18 @@
 """

 import os
-from scipy.io import loadmat
 from glob import glob

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
-import pdb

 __all__ = ['SYSU_mm', ]


 @DATASET_REGISTRY.register()
 class SYSU_mm(ImageDataset):
+    """sysu mm
+    """
     dataset_dir = "SYSU-MM01"
     dataset_name = "sysumm01"

@@ -35,7 +35,7 @@ class SYSU_mm(ImageDataset):
         data = []

         file_path_list = ['cam1', 'cam2', 'cam4', 'cam5']

         for file_path in file_path_list:
             camid = self.dataset_name + "_" + file_path
             pid_list = os.listdir(os.path.join(train_path, file_path))

@@ -45,4 +45,3 @@ class SYSU_mm(ImageDataset):
             for img_path in img_list:
                 data.append([img_path, pid, camid])
         return data
-

@@ -5,20 +5,18 @@
 """

 import os
-from scipy.io import loadmat
 from glob import glob

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
-import pdb
-import random
-import numpy as np

-__all__ = ['Thermalworld',]
+__all__ = ['Thermalworld', ]


 @DATASET_REGISTRY.register()
 class Thermalworld(ImageDataset):
+    """thermal world
+    """
     dataset_dir = "thermalworld_rgb"
     dataset_name = "thermalworld"

@@ -38,7 +36,7 @@ class Thermalworld(ImageDataset):
         pid_list = os.listdir(train_path)
         for pid_dir in pid_list:
             pid = self.dataset_name + "_" + pid_dir
-            img_list = glob(os.path.join(train_path, pid_dir, "*.jpg"))
+            img_list = glob(os.path.join(train_path, pid_dir, "*.jpg"))
             for img_path in img_list:
                 camid = self.dataset_name + "_cam0"
                 data.append([img_path, pid, camid])

@@ -58,10 +58,12 @@ class VehicleID(ImageDataset):
             line = line.strip()
             vid = int(line.split(' ')[1])
             imgid = line.split(' ')[0]
-            img_path = osp.join(self.image_dir, imgid + '.jpg')
+            img_path = osp.join(self.image_dir, f"{imgid}.jpg")
+            imgid = int(imgid)
             if is_train:
-                vid = self.dataset_name + "_" + str(vid)
-            dataset.append((img_path, vid, int(imgid)))
+                vid = f"{self.dataset_name}_{vid}"
+                imgid = f"{self.dataset_name}_{imgid}"
+            dataset.append((img_path, vid, imgid))

         if is_train: return dataset
         else:

@@ -63,10 +63,12 @@ class VeRiWild(ImageDataset):
         for idx, line in enumerate(img_list_lines):
             line = line.strip()
             vid = int(line.split('/')[0])
-            imgid = line.split('/')[1]
+            imgid = line.split('/')[1].split('.')[0]
+            camid = int(self.imgid2camid[imgid])
             if is_train:
-                vid = self.dataset_name + "_" + str(vid)
-            dataset.append((self.imgid2imgpath[imgid], vid, int(self.imgid2camid[imgid])))
+                vid = f"{self.dataset_name}_{vid}"
+                camid = f"{self.dataset_name}_{camid}"
+            dataset.append((self.imgid2imgpath[imgid], vid, camid))

         assert len(dataset) == len(img_list_lines)
         return dataset

@@ -4,5 +4,15 @@
 @contact: sherlockliao01@gmail.com
 """

-from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler
+from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
 from .data_sampler import TrainingSampler, InferenceSampler
+from .imbalance_sampler import ImbalancedDatasetSampler
+
+__all__ = [
+    "BalancedIdentitySampler",
+    "NaiveIdentitySampler",
+    "SetReWeightSampler",
+    "TrainingSampler",
+    "InferenceSampler",
+    "ImbalancedDatasetSampler",
+]

@@ -0,0 +1,67 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on:
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py


import itertools
from typing import Optional, List, Callable

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

from fastreid.utils import comm


class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        data_source: a list of data items
        size: number of samples to draw
    """

    def __init__(self, data_source: List, size: int = None, seed: Optional[int] = None,
                 callback_get_label: Callable = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(data_source, idx)
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(data_source, idx)] for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            return dataset[idx][1]

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        np.random.seed(self._seed)
        while True:
            for i in torch.multinomial(self.weights, self._size, replacement=True):
                yield self.indices[i]

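For intuition, a toy invocation (the identities and paths here are invented): each sample's draw probability is 1 / count(label), so the minority identity is oversampled, and the iterator is infinite by design.

# Toy invocation; identities and paths are invented for the example.
data_source = [
    ("a1.jpg", "pid_a", "cam_0"),
    ("a2.jpg", "pid_a", "cam_0"),
    ("a3.jpg", "pid_a", "cam_1"),
    ("b1.jpg", "pid_b", "cam_1"),   # minority identity: weight 1/1 vs 1/3
]
sampler = ImbalancedDatasetSampler(data_source)
it = iter(sampler)
some_indices = [next(it) for _ in range(8)]   # infinite stream; take a few
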
@@ -7,7 +7,7 @@
 import copy
 import itertools
 from collections import defaultdict
-from typing import Optional
+from typing import Optional, List

 import numpy as np
 from torch.utils.data.sampler import Sampler

@@ -39,7 +39,7 @@ def reorder_index(batch_indices, world_size):


 class BalancedIdentitySampler(Sampler):
-    def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
+    def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
         self.data_source = data_source
         self.num_instances = num_instances
         self.num_pids_per_batch = mini_batch_size // self.num_instances

@@ -119,6 +119,82 @@ class BalancedIdentitySampler(Sampler):
                batch_indices = []


class SetReWeightSampler(Sampler):
    def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, set_weight: list,
                 seed: Optional[int] = None):
        self.data_source = data_source
        self.num_instances = num_instances
        self.num_pids_per_batch = mini_batch_size // self.num_instances

        self.set_weight = set_weight

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.batch_size = mini_batch_size * self._world_size

        assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
               self.batch_size > sum(self.set_weight) * self.num_instances, \
            "Batch size must be larger than, and divisible by, sum(set_weight) * num_instances"

        self.index_pid = dict()
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        self.cam_pid = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)
            self.cam_pid[camid].append(pid)

        # Get sampler prob for each cam
        self.set_pid_prob = defaultdict(list)
        for camid, pid_list in self.cam_pid.items():
            index_per_pid = []
            for pid in pid_list:
                index_per_pid.append(len(self.pid_index[pid]))
            cam_image_number = sum(index_per_pid)
            prob = [i / cam_image_number for i in index_per_pid]
            self.set_pid_prob[camid] = prob

        self.pids = sorted(list(self.pid_index.keys()))
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        np.random.seed(self._seed)
        while True:
            batch_indices = []
            for camid in range(len(self.cam_pid.keys())):
                select_pids = np.random.choice(self.cam_pid[camid], size=self.set_weight[camid], replace=False,
                                               p=self.set_pid_prob[camid])
                for pid in select_pids:
                    index_list = self.pid_index[pid]
                    if len(index_list) > self.num_instances:
                        select_indexs = np.random.choice(index_list, size=self.num_instances, replace=False)
                    else:
                        select_indexs = np.random.choice(index_list, size=self.num_instances, replace=True)

                    batch_indices += select_indexs
            np.random.shuffle(batch_indices)

            if len(batch_indices) == self.batch_size:
                yield from reorder_index(batch_indices, self._world_size)


class NaiveIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,

@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """

-from .autoaugment import *
+from .autoaugment import AutoAugment
 from .build import build_transforms
 from .transforms import *
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]

@@ -16,22 +16,28 @@ def build_transforms(cfg, is_train=True):
     if is_train:
         size_train = cfg.INPUT.SIZE_TRAIN

+        # crop
+        do_crop = cfg.INPUT.CROP.ENABLED
+        crop_size = cfg.INPUT.CROP.SIZE
+        crop_scale = cfg.INPUT.CROP.SCALE
+        crop_ratio = cfg.INPUT.CROP.RATIO
+
         # augmix augmentation
-        do_augmix = cfg.INPUT.DO_AUGMIX
-        augmix_prob = cfg.INPUT.AUGMIX_PROB
+        do_augmix = cfg.INPUT.AUGMIX.ENABLED
+        augmix_prob = cfg.INPUT.AUGMIX.PROB

         # auto augmentation
-        do_autoaug = cfg.INPUT.DO_AUTOAUG
-        autoaug_prob = cfg.INPUT.AUTOAUG_PROB
+        do_autoaug = cfg.INPUT.AUTOAUG.ENABLED
+        autoaug_prob = cfg.INPUT.AUTOAUG.PROB

         # horizontal flip
-        do_flip = cfg.INPUT.DO_FLIP
-        flip_prob = cfg.INPUT.FLIP_PROB
+        do_flip = cfg.INPUT.FLIP.ENABLED
+        flip_prob = cfg.INPUT.FLIP.PROB

         # padding
-        do_pad = cfg.INPUT.DO_PAD
-        padding = cfg.INPUT.PADDING
-        padding_mode = cfg.INPUT.PADDING_MODE
+        do_pad = cfg.INPUT.PADDING.ENABLED
+        padding_size = cfg.INPUT.PADDING.SIZE
+        padding_mode = cfg.INPUT.PADDING.MODE

         # color jitter
         do_cj = cfg.INPUT.CJ.ENABLED

@@ -42,7 +48,7 @@ def build_transforms(cfg, is_train=True):
         cj_hue = cfg.INPUT.CJ.HUE

         # random affine
-        do_affine = cfg.INPUT.DO_AFFINE
+        do_affine = cfg.INPUT.AFFINE.ENABLED

         # random erasing
         do_rea = cfg.INPUT.REA.ENABLED

@@ -56,16 +62,24 @@ def build_transforms(cfg, is_train=True):
         if do_autoaug:
             res.append(T.RandomApply([AutoAugment()], p=autoaug_prob))

-        res.append(T.Resize(size_train, interpolation=3))
+        if size_train[0] > 0:
+            res.append(T.Resize(size_train[0] if len(size_train) == 1 else size_train, interpolation=3))
+
+        if do_crop:
+            res.append(T.RandomResizedCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size,
+                                           interpolation=3,
+                                           scale=crop_scale, ratio=crop_ratio))
+        if do_pad:
+            res.extend([T.Pad(padding_size, padding_mode=padding_mode),
+                        T.RandomCrop(size_train[0] if len(size_train) == 1 else size_train)])
         if do_flip:
             res.append(T.RandomHorizontalFlip(p=flip_prob))
-        if do_pad:
-            res.extend([T.Pad(padding, padding_mode=padding_mode), T.RandomCrop(size_train)])
+
         if do_cj:
             res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
         if do_affine:
-            res.append(T.RandomAffine(degrees=0, translate=None, scale=[0.9, 1.1], shear=None, resample=False,
-                                      fillcolor=128))
+            res.append(T.RandomAffine(degrees=10, translate=None, scale=[0.9, 1.1], shear=0.1, resample=False,
+                                      fillcolor=0))
         if do_augmix:
             res.append(AugMix(prob=augmix_prob))
         res.append(ToTensor())

@@ -75,6 +89,12 @@ def build_transforms(cfg, is_train=True):
             res.append(RandomPatch(prob_happen=rpt_prob))
     else:
         size_test = cfg.INPUT.SIZE_TEST
-        res.append(T.Resize(size_test, interpolation=3))
+        do_crop = cfg.INPUT.CROP.ENABLED
+        crop_size = cfg.INPUT.CROP.SIZE
+
+        if size_test[0] > 0:
+            res.append(T.Resize(size_test[0] if len(size_test) == 1 else size_test, interpolation=3))
+        if do_crop:
+            res.append(T.CenterCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size))
         res.append(ToTensor())
     return T.Compose(res)

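A short sketch of driving both pipelines from one cfg; only the deterministic ops (Resize / CenterCrop / ToTensor) survive at test time.

# Sketch: one cfg drives both the train and test transform pipelines.
from fastreid.config import get_cfg
from fastreid.data.transforms import build_transforms

cfg = get_cfg()
cfg.INPUT.CROP.ENABLED = True
cfg.INPUT.FLIP.ENABLED = True

train_tfms = build_transforms(cfg, is_train=True)   # resize, crop, flip, ...
test_tfms = build_transforms(cfg, is_train=False)   # resize + center crop only
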
@@ -11,7 +11,7 @@ import random
from collections import deque

import numpy as np
from PIL import Image
import torch

from .functional import to_tensor, augmentations

@@ -55,8 +55,7 @@ class RandomPatch(object):
    """

    def __init__(self, prob_happen=0.5, pool_capacity=50000, min_sample_size=100,
                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
                 prob_rotate=0.5, prob_flip_leftright=0.5,
                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1, prob_flip_leftright=0.5,
                 ):
        self.prob_happen = prob_happen

@@ -64,7 +63,6 @@ class RandomPatch(object):
        self.patch_max_area = patch_max_area
        self.patch_min_ratio = patch_min_ratio

        self.prob_rotate = prob_rotate
        self.prob_flip_leftright = prob_flip_leftright

        self.patchpool = deque(maxlen=pool_capacity)

@@ -83,23 +81,18 @@ class RandomPatch(object):

    def transform_patch(self, patch):
        if random.uniform(0, 1) > self.prob_flip_leftright:
            patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
        if random.uniform(0, 1) > self.prob_rotate:
            patch = patch.rotate(random.randint(-10, 10))
            patch = torch.flip(patch, dims=[2])
        return patch

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype(np.uint8))

        W, H = img.size  # original image size
        _, H, W = img.size()  # original image size

        # collect new patch
        w, h = self.generate_wh(W, H)
        if w is not None and h is not None:
            x1 = random.randint(0, W - w)
            y1 = random.randint(0, H - h)
            new_patch = img.crop((x1, y1, x1 + w, y1 + h))
            new_patch = img[..., y1:y1 + h, x1:x1 + w]
            self.patchpool.append(new_patch)

        if len(self.patchpool) < self.min_sample_size:

@@ -110,52 +103,54 @@ class RandomPatch(object):

        # paste a randomly selected patch on a random position
        patch = random.sample(self.patchpool, 1)[0]
        patchW, patchH = patch.size
        _, patchH, patchW = patch.size()
        x1 = random.randint(0, W - patchW)
        y1 = random.randint(0, H - patchH)
        patch = self.transform_patch(patch)
        img.paste(patch, (x1, y1))
        img[..., y1:y1 + patchH, x1:x1 + patchW] = patch

        return img


class AugMix(object):
    """ Perform AugMix augmentation and compute mixture.
    Args:
        prob: Probability of taking augmix
        aug_prob_coeff: Probability distribution coefficients.
        mixture_width: Number of augmentation chains to mix per augmented example.
        mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]'
        aug_severity: Severity of underlying augmentation operators (between 1 to 10).
    """

    def __init__(self, prob=0.5, aug_prob_coeff=0.1, mixture_width=3, mixture_depth=1, aug_severity=1):
        self.prob = prob
        """
        Args:
            prob: Probability of taking augmix
            aug_prob_coeff: Probability distribution coefficients.
            mixture_width: Number of augmentation chains to mix per augmented example.
            mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]'
            aug_severity: Severity of underlying augmentation operators (between 1 to 10).
        """
        # fmt: off
        self.prob = prob
        self.aug_prob_coeff = aug_prob_coeff
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth
        self.aug_severity = aug_severity
        self.augmentations = augmentations
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth
        self.aug_severity = aug_severity
        self.augmentations = augmentations
        # fmt: on

    def __call__(self, image):
        """Perform AugMix augmentations and compute mixture.

        Returns:
            mixed: Augmented and mixed image.
        """
        if random.random() > self.prob:
            return np.asarray(image)
            # Avoid the warning: the given NumPy array is not writeable
            return np.asarray(image).copy()

        ws = np.float32(
            np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
        m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))

        # image = np.asarray(image, dtype=np.float32).copy()
        # mix = np.zeros_like(image)
        mix = np.zeros([image.size[1], image.size[0], 3])
        # h, w = image.shape[0], image.shape[1]
        for i in range(self.mixture_width):
            image_aug = image.copy()
            # image_aug = Image.fromarray(image.copy().astype(np.uint8))
            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augmentations)

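The RandomPatch edits swap PIL operations (`img.size`, `crop`, `paste`, `transpose`) for tensor indexing on CHW tensors (`img.size()`, slicing, `torch.flip(..., dims=[2])`). A standalone sketch of the tensor-based crop/flip/paste idiom, with made-up sizes and coordinates:

import torch

img = torch.rand(3, 256, 128)           # CHW tensor, as produced by ToTensor()
_, H, W = img.size()

h, w = 40, 20                           # patch size (illustrative)
y1, x1 = 30, 50
patch = img[..., y1:y1 + h, x1:x1 + w]  # crop via slicing, replaces PIL img.crop()

patch = torch.flip(patch, dims=[2])     # horizontal flip: dim 2 is width in CHW

y2, x2 = 100, 10
img[..., y2:y2 + h, x2:x2 + w] = patch  # paste via assignment, replaces PIL img.paste()
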
@@ -11,12 +11,11 @@ since they are meant to represent the "common default behavior" people need in t
import argparse
import logging
import os
import math
import sys
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation import (ReidEvaluator,

@@ -33,14 +32,6 @@ from fastreid.utils.logger import setup_logger
from . import hooks
from .train_loop import TrainerBase, AMPTrainer, SimpleTrainer

try:
    import apex
    from apex import amp
    from apex.parallel import DistributedDataParallel
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
                      "train with DDP")

__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]


@@ -93,7 +84,7 @@ def default_setup(cfg, args):
    PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    # setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))

@@ -157,13 +148,10 @@ class DefaultPredictor:
        Returns:
            predictions (torch.tensor): the output features of the model
        """
        inputs = {"images": image}
        inputs = {"images": image.to(self.model.device)}
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            predictions = self.model(inputs)
            # Normalize feature to compute cosine distance
            features = F.normalize(predictions)
            features = features.cpu().data
            return features
            return predictions.cpu()


class DefaultTrainer(TrainerBase):

@@ -213,27 +201,22 @@ class DefaultTrainer(TrainerBase):
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        optimizer_ckpt = dict(optimizer=optimizer)
        if cfg.SOLVER.FP16_ENABLED:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
            optimizer_ckpt.update(dict(amp=amp))
        optimizer, param_wrapper = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            # model = DistributedDataParallel(
            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            # )
            model = DistributedDataParallel(model, delay_allreduce=True)
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
            )

        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer, param_wrapper
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks

@@ -242,16 +225,14 @@ class DefaultTrainer(TrainerBase):
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            **optimizer_ckpt,
            optimizer=optimizer,
            **self.scheduler,
        )

        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg


@@ -296,17 +277,6 @@ class DefaultTrainer(TrainerBase):
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        # if cfg.SOLVER.SWA.ENABLED:
        #     ret.append(
        #         hooks.SWA(
        #             cfg.SOLVER.MAX_ITER,
        #             cfg.SOLVER.SWA.PERIOD,
        #             cfg.SOLVER.SWA.LR_FACTOR,
        #             cfg.SOLVER.SWA.ETA_MIN_LR,
        #             cfg.SOLVER.SWA.LR_SCHED,
        #         )
        #     )

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
            logger.info("Prepare precise BN dataset")
            ret.append(hooks.PreciseBN(

@@ -317,12 +287,13 @@ class DefaultTrainer(TrainerBase):
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ))

        ret.append(hooks.LayerFreeze(
            self.model,
            cfg.MODEL.FREEZE_LAYERS,
            cfg.SOLVER.FREEZE_ITERS,
            cfg.SOLVER.FREEZE_FC_ITERS,
        ))
        if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
            ret.append(hooks.LayerFreeze(
                self.model,
                cfg.MODEL.FREEZE_LAYERS,
                cfg.SOLVER.FREEZE_ITERS,
            ))

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,

@@ -409,12 +380,12 @@ class DefaultTrainer(TrainerBase):
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)
        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    @classmethod
    def build_train_loader(cls, cfg):

@@ -426,7 +397,7 @@ class DefaultTrainer(TrainerBase):
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)
        return build_reid_train_loader(cfg, combineall=cfg.DATASETS.COMBINEALL)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):

@@ -436,7 +407,7 @@ class DefaultTrainer(TrainerBase):
        It now calls :func:`fastreid.data.build_reid_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)
        return build_reid_test_loader(cfg, dataset_name=dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):

@@ -465,18 +436,21 @@ class DefaultTrainer(TrainerBase):
                )
                results[dataset_name] = {}
                continue
            results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
            results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED)
            results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            print_csv_format(results)
            if comm.is_main_process():
                assert isinstance(
                    results, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                results_i['dataset'] = dataset_name
                print_csv_format(results_i)

        if len(results) == 1: results = list(results.values())[0]
        if len(results) == 1:
            results = list(results.values())[0]

        return results


@@ -512,5 +486,5 @@ class DefaultTrainer(TrainerBase):


# Access basic attributes from the underlying trainer
for _attr in ["model", "data_loader", "optimizer"]:
    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x)))
for _attr in ["model", "data_loader", "optimizer", "grad_scaler"]:
    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x, None)))

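Two changes dominate this file: `build_optimizer` now returns an `(optimizer, param_wrapper)` pair, and apex's `DistributedDataParallel(delay_allreduce=True)` gives way to the stock `torch.nn.parallel.DistributedDataParallel`. A minimal sketch of the native wrap under the same assumptions as the diff (one GPU per process, `torch.distributed` already initialized); `wrap_for_ddp` is a made-up helper name:

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # broadcast_buffers=False matches the diff: BN buffers are handled by
    # the PreciseBN hook rather than being broadcast on every forward.
    return DistributedDataParallel(
        model.to(local_rank),
        device_ids=[local_rank],
        broadcast_buffers=False,
    )
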
@@ -11,7 +11,7 @@ from collections import Counter

import torch
from torch import nn
from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel

from fastreid.evaluation.testing import flatten_results_dict
from fastreid.solver import optim

@@ -226,6 +226,7 @@ class LRScheduler(HookBase):
        """
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._scale = 0

        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with most parameters

@@ -246,15 +247,23 @@ class LRScheduler(HookBase):
                self._best_param_group_id = i
                break

    def before_step(self):
        if self.trainer.grad_scaler is not None:
            self._scale = self.trainer.grad_scaler.get_scale()

    def after_step(self):
        lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)

        next_iter = self.trainer.iter + 1
        if next_iter <= self.trainer.warmup_iters:
            if self.trainer.grad_scaler is None or self._scale == self.trainer.grad_scaler.get_scale():
                self._scheduler["warmup_sched"].step()

    def after_epoch(self):
        next_iter = self.trainer.iter + 1
        next_epoch = self.trainer.epoch + 1
        if next_epoch <= self.trainer.warmup_epochs:
            self._scheduler["warmup_sched"].step()
        elif next_epoch >= self.trainer.delay_epochs:
        if next_iter > self.trainer.warmup_iters and next_epoch > self.trainer.delay_epochs:
            self._scheduler["lr_sched"].step()


@@ -357,19 +366,21 @@ class EvalHook(HookBase):
        )
        self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Remove extra memory cache of main process due to evaluation
        torch.cuda.empty_cache()

    def after_epoch(self):
        next_epoch = self.trainer.epoch + 1
        is_final = next_epoch == self.trainer.max_epoch
        if is_final or (self._period > 0 and next_epoch % self._period == 0):
            self._do_eval()
        # Evaluation may take different time among workers.
        # A barrier make them start the next iteration together.
        comm.synchronize()

    def after_epoch(self):
        next_epoch = self.trainer.epoch + 1
        if self._period > 0 and next_epoch % self._period == 0:
            self._do_eval()

    def after_train(self):
        next_epoch = self.trainer.epoch + 1
        # This condition is to prevent the eval from running after a failed training
        if next_epoch % self._period != 0 and next_epoch >= self.trainer.max_epoch:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func

@@ -444,19 +455,16 @@ class PreciseBN(HookBase):


class LayerFreeze(HookBase):
    def __init__(self, model, freeze_layers, freeze_iters, fc_freeze_iters):
    def __init__(self, model, freeze_layers, freeze_iters):
        self._logger = logging.getLogger(__name__)

        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.model = model

        self.freeze_layers = freeze_layers
        self.freeze_iters = freeze_iters
        self.fc_freeze_iters = fc_freeze_iters

        self.is_frozen = False
        self.fc_frozen = False

    def before_step(self):
        # Freeze specific layers

@@ -467,18 +475,6 @@ class LayerFreeze(HookBase):
        if self.trainer.iter >= self.freeze_iters and self.is_frozen:
            self.open_all_layer()

        if self.trainer.max_iter - self.trainer.iter <= self.fc_freeze_iters \
                and not self.fc_frozen:
            self.freeze_classifier()

    def freeze_classifier(self):
        for p in self.model.heads.classifier.parameters():
            p.requires_grad_(False)

        self.fc_frozen = True
        self._logger.info("Freeze classifier training for "
                          "last {} iterations".format(self.fc_freeze_iters))

    def freeze_specific_layer(self):
        for layer in self.freeze_layers:
            if not hasattr(self.model, layer):

@@ -488,8 +484,6 @@ class LayerFreeze(HookBase):
            if name in self.freeze_layers:
                # Change BN in freeze layers to eval mode
                module.eval()
                for p in module.parameters():
                    p.requires_grad_(False)

        self.is_frozen = True
        freeze_layers = ", ".join(self.freeze_layers)

@@ -499,8 +493,6 @@ class LayerFreeze(HookBase):
        for name, module in self.model.named_children():
            if name in self.freeze_layers:
                module.train()
                for p in module.parameters():
                    p.requires_grad_(True)

        self.is_frozen = False

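The new `before_step`/`after_step` pair guards warmup stepping against skipped AMP updates: when the torch `GradScaler` finds inf/NaN gradients it skips `optimizer.step()` and lowers its scale, so the hook only advances the warmup scheduler when the scale is unchanged across the step. The pattern in isolation, as a sketch:

import torch

scaler = torch.cuda.amp.GradScaler()   # disabled (scale fixed at 1.0) when CUDA is absent

prev_scale = scaler.get_scale()        # what before_step() records
# ... the trainer runs scaler.step(optimizer); scaler.update() ...
if prev_scale == scaler.get_scale():   # after_step(): the optimizer step was not skipped
    pass                               # safe to call warmup_sched.step() here
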
@@ -11,11 +11,11 @@ from typing import Dict

import numpy as np
import torch
from apex import amp
from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DataParallel, DistributedDataParallel

import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage, get_event_storage
from fastreid.utils.params import ContiguousParams

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]


@@ -98,9 +98,10 @@ class TrainerBase:
    We made no assumptions about the existence of dataloader, optimizer, model, etc.
    Attributes:
        iter(int): the current iteration.
        epoch(int): the current epoch.
        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.
        max_iter(int): The iteration to end training.
        max_epoch (int): The epoch to end training.
        storage(EventStorage): An EventStorage that's opened during the course of training.
    """


@@ -127,7 +128,7 @@ class TrainerBase:
    def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
            start_epoch, max_epoch (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from epoch {}".format(start_epoch))

@@ -197,7 +198,7 @@ class SimpleTrainer(TrainerBase):
    or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer):
    def __init__(self, model, data_loader, optimizer, param_wrapper):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a

@@ -219,6 +220,7 @@ class SimpleTrainer(TrainerBase):
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer
        self.param_wrapper = param_wrapper

    def run_step(self):
        """

@@ -254,6 +256,8 @@ class SimpleTrainer(TrainerBase):
        wrap the optimizer with your custom `step()` method.
        """
        self.optimizer.step()
        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
        """

@@ -299,29 +303,52 @@ class SimpleTrainer(TrainerBase):

class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses apex automatic mixed precision
    Like :class:`SimpleTrainer`, but uses automatic mixed precision
    in the training loop.
    """

    def __init__(self, model, data_loader, optimizer, param_wrapper, grad_scaler=None):
        """

        Args:
            model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients.
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        super().__init__(model, data_loader, optimizer, param_wrapper)

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        loss_dict = self.model(data)
        losses = sum(loss_dict.values())
        with autocast():
            loss_dict = self.model(data)
            losses = sum(loss_dict.values())

        self.optimizer.zero_grad()

        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.optimizer.step()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

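The rewritten `run_step` is the canonical native-AMP recipe that replaced `apex.amp.scale_loss`: forward under `autocast()`, scale the loss before `backward()`, then `scaler.step()` and `scaler.update()`. A self-contained sketch on invented shapes (requires a CUDA device, mirroring the assert in the diff):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(128, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(32, 128, device="cuda")
y = torch.randint(0, 10, (32,), device="cuda")

with autocast():                 # forward + loss computed in mixed precision
    loss = nn.functional.cross_entropy(model(x), y)

opt.zero_grad()
scaler.scale(loss).backward()    # scale the loss to avoid fp16 gradient underflow
scaler.step(opt)                 # unscales grads; skips the step on inf/NaN
scaler.update()                  # adjusts the scale for the next iteration
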
@@ -1,8 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
from .rank import evaluate_rank
from .roc import evaluate_roc
from .reid_evaluation import ReidEvaluator
from .clas_evaluator import ClasEvaluator
from .testing import print_csv_format, verify_results

__all__ = [k for k in globals().keys() if not k.startswith("_")]

@@ -5,15 +5,16 @@
"""

import copy
import itertools
import logging
from collections import OrderedDict

import torch

from fastreid.evaluation import DatasetEvaluator
from fastreid.utils import comm
from .evaluator import DatasetEvaluator

logger = logging.getLogger("fastreid.cls_evaluator")
logger = logging.getLogger(__name__)


def accuracy(output, target, topk=(1,)):

@@ -33,47 +34,48 @@ def accuracy(output, target, topk=(1,)):
    return res


class ClsEvaluator(DatasetEvaluator):
class ClasEvaluator(DatasetEvaluator):
    def __init__(self, cfg, output_dir=None):
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')

        self.pred_logits = []
        self.labels = []
        self._predictions = []

    def reset(self):
        self.pred_logits = []
        self.labels = []
        self._predictions = []

    def process(self, inputs, outputs):
        self.pred_logits.append(outputs.cpu())
        self.labels.extend(inputs["targets"])
        pred_logits = outputs.to(self._cpu_device, torch.float32)
        labels = inputs["targets"].to(self._cpu_device)

        # measure accuracy
        acc1, = accuracy(pred_logits, labels, topk=(1,))
        num_correct_acc1 = acc1 * labels.size(0) / 100

        self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            pred_logits = comm.gather(self.pred_logits)
            pred_logits = sum(pred_logits, [])
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            labels = comm.gather(self.labels)
            labels = sum(labels, [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on

        else:
            pred_logits = self.pred_logits
            labels = self.labels
            predictions = self._predictions

        pred_logits = torch.cat(pred_logits, dim=0)
        labels = torch.stack(labels)
        total_correct_num = 0
        total_samples = 0
        for prediction in predictions:
            total_correct_num += prediction["num_correct"]
            total_samples += prediction["num_samples"]

        # measure accuracy and record loss
        acc1, = accuracy(pred_logits, labels, topk=(1,))
        acc1 = total_correct_num / total_samples * 100

        self._results = OrderedDict()
        self._results["Acc@1"] = acc1

        self._results["metric"] = acc1

        return copy.deepcopy(self._results)

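Instead of gathering full logit tensors across processes, the evaluator now reduces each batch to a small `{num_correct, num_samples}` record on the CPU and gathers only those; final top-1 accuracy is a ratio of sums. The aggregation, reduced to its core with invented counts:

predictions = [                               # per-batch records, counts invented
    {"num_correct": 30.0, "num_samples": 32},
    {"num_correct": 28.0, "num_samples": 32},
]

total_correct = sum(p["num_correct"] for p in predictions)
total_samples = sum(p["num_samples"] for p in predictions)
acc1 = total_correct / total_samples * 100
print(f"Acc@1 = {acc1:.3f}")                  # Acc@1 = 90.625
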
@@ -6,6 +6,7 @@ from contextlib import contextmanager

import torch

from fastreid.utils import comm
from fastreid.utils.logger import log_every_n_seconds


@@ -96,6 +97,7 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
    Returns:
        The return value of `evaluator.evaluate()`
    """
    num_devices = comm.get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader.dataset)))


@@ -118,10 +120,11 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
                inputs["images"] = inputs["images"].flip(dims=[3])
                flip_outputs = model(inputs)
                outputs = (outputs + flip_outputs) / 2
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            evaluator.process(inputs, outputs)

            idx += 1
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_batch = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_batch > 30:

@@ -140,17 +143,18 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / batch per device)".format(
            total_time_str, total_time / (total - num_warmup)
        "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / batch per device)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup)
        "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )
    results = evaluator.evaluate()

    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle
    if results is None:

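Flip testing averages each image's embedding with that of its horizontal mirror; the diff puts it behind `cfg.TEST.FLIP.ENABLED` and flips dim 3, the width axis of an NCHW batch. A self-contained sketch with a stand-in embedder:

import torch
from torch import nn

model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())  # stand-in embedder
images = torch.rand(4, 3, 256, 128)                           # NCHW batch

feats = model(images)
flip_feats = model(images.flip(dims=[3]))  # dim 3 is width, as in the diff
outputs = (feats + flip_feats) / 2         # flip-test averaging
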
@@ -107,9 +107,6 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
        print('Note: number of gallery samples is quite small, got {}'.format(num_g))

    indices = np.argsort(distmat, axis=1)

    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []

@@ -127,7 +124,8 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
        keep = np.invert(remove)

        # compute cmc curve
        raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        matches = (g_pids[order] == q_pid).astype(np.int32)
        raw_cmc = matches[keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

@@ -163,7 +161,7 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):

def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03):
    if use_metric_cuhk03:
        return eval_cuhk03(distmat, g_pids, q_camids, g_camids, max_rank)
        return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
    else:
        return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)


@@ -176,7 +174,7 @@ def evaluate_rank(
    g_camids,
    max_rank=50,
    use_metric_cuhk03=False,
    use_cython=True
    use_cython=True,
):
    """Evaluates CMC rank.
    Args:

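The `eval_market1501` change drops the precomputed `num_q x num_g` `matches` matrix in favor of recomputing one row per query inside the loop, trading a potentially huge allocation for a cheap vector comparison. The per-query form, on invented toy data:

import numpy as np

distmat = np.random.rand(2, 5).astype(np.float32)  # invented: 2 queries, 5 gallery items
g_pids = np.array([1, 2, 1, 3, 2])
q_pids = np.array([1, 2])

indices = np.argsort(distmat, axis=1)
for q_idx in range(len(q_pids)):
    order = indices[q_idx]
    # one row at a time instead of a full (num_q, num_g) matches matrix
    matches = (g_pids[order] == q_pids[q_idx]).astype(np.int32)
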
@@ -1,7 +1,6 @@
all:
	python3 setup.py build_ext --inplace
	rm -rf build
	python3 test_cython.py
clean:
	rm -rf build
	rm -f rank_cy.c *.so

@@ -2,4 +2,19 @@
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
"""


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(["make", "-C", path])
    if ret.returncode != 0:
        print("Making cython reid evaluation module failed, exiting.")
        import sys

        sys.exit(1)

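`compile_helper` shells out to `make` in the package directory, so it must run on exactly one process. A usage sketch mirroring the `_compile_dependencies` method that appears later in this diff:

from fastreid.utils import comm
from fastreid.evaluation.rank_cylib import compile_helper

if comm.is_main_process():  # build the Cython extension once, on rank 0 only
    compile_helper()
comm.synchronize()          # other ranks wait until the compiled module exists
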
@@ -5,7 +5,6 @@ import cython
import numpy as np
cimport numpy as np
from collections import defaultdict
import faiss

"""

@@ -160,7 +159,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,

    cdef:
        long[:,:] indices = np.argsort(distmat, axis=1)
        long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
        long[:] matches

        float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
        float[:] all_AP = np.zeros(num_q, dtype=np.float32)

@@ -193,14 +192,15 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
            order[g_idx] = indices[q_idx, g_idx]
        num_g_real = 0
        meet_condition = 0
        matches = (np.asarray(g_pids)[np.asarray(order)] == q_pid).astype(np.int64)

        # remove gallery samples that have the same pid and camid with query
        for g_idx in range(num_g):
            if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
                raw_cmc[num_g_real] = matches[q_idx][g_idx]
                raw_cmc[num_g_real] = matches[g_idx]
                num_g_real += 1
                # this condition is true if query appear in gallery
                if matches[q_idx][g_idx] > 1e-31:
                if matches[g_idx] > 1e-31:
                    meet_condition = 1

        if not meet_condition:

@@ -5,8 +5,8 @@ import os.path as osp

sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')

from fastreid.evaluation import evaluate_rank
from fastreid.evaluation import evaluate_roc
from fastreid.evaluation.rank import evaluate_rank
from fastreid.evaluation.roc import evaluate_roc

"""
Test the speed of cython-based evaluation code. The speed improvements

@@ -24,8 +24,8 @@ import sys
import os.path as osp
import numpy as np
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
from fastreid.evaluation import evaluate_rank
from fastreid.evaluation import evaluate_roc
from fastreid.evaluation.rank import evaluate_rank
from fastreid.evaluation.roc import evaluate_roc
num_q = 30
num_g = 300
dim = 512

@@ -5,6 +5,8 @@
"""
import copy
import logging
import time
import itertools
from collections import OrderedDict

import numpy as np

@@ -16,8 +18,7 @@ from fastreid.utils import comm
from fastreid.utils.compute_dist import build_dist
from .evaluator import DatasetEvaluator
from .query_expansion import aqe
from .rank import evaluate_rank
from .roc import evaluate_roc
from .rank_cylib import compile_helper

logger = logging.getLogger(__name__)


@@ -28,50 +29,55 @@ class ReidEvaluator(DatasetEvaluator):
        self._num_query = num_query
        self._output_dir = output_dir

        self.features = []
        self.pids = []
        self.camids = []
        self._cpu_device = torch.device('cpu')

        self._predictions = []
        self._compile_dependencies()

    def reset(self):
        self.features = []
        self.pids = []
        self.camids = []
        self._predictions = []

    def process(self, inputs, outputs):
        self.pids.extend(inputs["targets"])
        self.camids.extend(inputs["camids"])
        self.features.append(outputs.cpu())
        prediction = {
            'feats': outputs.to(self._cpu_device, torch.float32),
            'pids': inputs['targets'].to(self._cpu_device),
            'camids': inputs['camids'].to(self._cpu_device)

        }
        self._predictions.append(prediction)

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            features = comm.gather(self.features)
            features = sum(features, [])
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            pids = comm.gather(self.pids)
            pids = sum(pids, [])
            if not comm.is_main_process():
                return {}

            camids = comm.gather(self.camids)
            camids = sum(camids, [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features
            pids = self.pids
            camids = self.camids
            predictions = self._predictions

        features = []
        pids = []
        camids = []
        for prediction in predictions:
            features.append(prediction['feats'])
            pids.append(prediction['pids'])
            camids.append(prediction['camids'])

        features = torch.cat(features, dim=0)
        pids = torch.cat(pids, dim=0).numpy()
        camids = torch.cat(camids, dim=0).numpy()
        # query feature, person ids and camera ids
        query_features = features[:self._num_query]
        query_pids = np.asarray(pids[:self._num_query])
        query_camids = np.asarray(camids[:self._num_query])
        query_pids = pids[:self._num_query]
        query_camids = camids[:self._num_query]

        # gallery features, person ids and camera ids
        gallery_features = features[self._num_query:]
        gallery_pids = np.asarray(pids[self._num_query:])
        gallery_camids = np.asarray(camids[self._num_query:])
        gallery_pids = pids[self._num_query:]
        gallery_camids = camids[self._num_query:]

        self._results = OrderedDict()


@@ -97,6 +103,7 @@ class ReidEvaluator(DatasetEvaluator):
            rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2)
            dist = rerank_dist * (1 - lambda_value) + dist * lambda_value

        from .rank import evaluate_rank
        cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)

        mAP = np.mean(all_AP)

@@ -107,7 +114,8 @@ class ReidEvaluator(DatasetEvaluator):
        self._results['mINP'] = mINP * 100
        self._results["metric"] = (mAP + cmc[0]) / 2 * 100

        if self.cfg.TEST.ROC_ENABLED:
        if self.cfg.TEST.ROC.ENABLED:
            from .roc import evaluate_roc
            scores, labels = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
            fprs, tprs, thres = metrics.roc_curve(labels, scores)


@@ -116,3 +124,20 @@ class ReidEvaluator(DatasetEvaluator):
            self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]

        return copy.deepcopy(self._results)

    def _compile_dependencies(self):
        # Since we only evaluate results in rank(0), so we just need to compile
        # cython evaluation tool on rank(0)
        if comm.is_main_process():
            try:
                from .rank_cylib.rank_cy import evaluate_cy
            except ImportError:
                start_time = time.time()
                logger.info("> compiling reid evaluation cython tool")

                compile_helper()

                logger.info(
                    ">>> done with reid evaluation cython tool. Compilation time: {:.3f} "
                    "seconds".format(time.time() - start_time))
        comm.synchronize()

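After gathering, features, pids, and camids are rebuilt from the per-batch prediction dicts and split positionally: the first `_num_query` rows are queries, the rest are gallery. The split in isolation, with invented sizes:

import torch

num_query = 3                       # invented: 3 query images, 2 gallery images
features = torch.randn(5, 8)        # gathered embeddings, query rows first
pids = torch.tensor([4, 7, 4, 4, 7])

query_features = features[:num_query]
gallery_features = features[num_query:]
query_pids = pids[:num_query].numpy()
gallery_pids = pids[num_query:].numpy()
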
@@ -8,23 +8,21 @@ import numpy as np
from tabulate import tabulate
from termcolor import colored

logger = logging.getLogger(__name__)


def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron,
    Print main metrics in a format similar to Detectron2,
    so that they are easy to copypaste into a spreadsheet.
    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}
        results (OrderedDict): {metric -> score}
    """
    assert isinstance(results, OrderedDict), results  # unordered results cannot be properly printed
    task = list(results.keys())[0]
    metrics = ["Datasets"] + [k for k in results[task]]
    # unordered results cannot be properly printed
    assert isinstance(results, OrderedDict) or not len(results), results
    logger = logging.getLogger(__name__)

    csv_results = []
    for task, res in results.items():
        csv_results.append((task, *list(res.values())))
    dataset_name = results.pop('dataset')
    metrics = ["Dataset"] + [k for k in results]
    csv_results = [(dataset_name, *list(results.values()))]

    # tabulate it
    table = tabulate(

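`print_csv_format` now takes a flat `OrderedDict` whose `'dataset'` entry names the dataset and is popped before tabulation, matching how `DefaultTrainer.test` sets `results_i['dataset']` earlier in this diff. A shape sketch with invented scores:

from collections import OrderedDict

results = OrderedDict([("Rank-1", 94.8), ("mAP", 86.2), ("dataset", "Market1501")])  # invented scores
dataset_name = results.pop("dataset")
headers = ["Dataset"] + list(results)    # ['Dataset', 'Rank-1', 'mAP']
row = (dataset_name, *results.values())  # ('Market1501', 94.8, 86.2)
print(headers, row)
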
@@ -5,15 +5,15 @@
"""

from .activation import *
from .arc_softmax import ArcSoftmax
from .circle_softmax import CircleSoftmax
from .cos_softmax import CosSoftmax
from .batch_drop import BatchDrop
from .batch_norm import *
from .context_block import ContextBlock
from .drop import DropPath, DropBlock2d, drop_block_2d, drop_path
from .frn import FRN, TLU
from .gather_layer import GatherLayer
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .non_local import Non_local
from .pooling import *
from .se_layer import SELayer
from .splat import SplAtConv2d, DropBlock2D
from .gather_layer import GatherLayer
from .weight_init import (
    trunc_normal_, variance_scaling_, lecun_normal_, weights_init_kaiming, weights_init_classifier
)

@@ -0,0 +1,80 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn as nn

__all__ = [
    "Linear",
    "ArcSoftmax",
    "CosSoftmax",
    "CircleSoftmax"
]


class Linear(nn.Module):
    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self.num_classes = num_classes
        self.s = scale
        self.m = margin

    def forward(self, logits, targets):
        return logits.mul_(self.s)

    def extra_repr(self):
        return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}"


class CosSoftmax(Linear):
    r"""Implement of large margin cosine distance:
    """

    def forward(self, logits, targets):
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        logits[index] -= m_hot
        logits.mul_(self.s)
        return logits


class ArcSoftmax(Linear):

    def forward(self, logits, targets):
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        logits.acos_()
        logits[index] += m_hot
        logits.cos_().mul_(self.s)
        return logits


class CircleSoftmax(Linear):

    def forward(self, logits, targets):
        alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        # When use model parallel, there are some targets not in class centers of local rank
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], 1)

        logits_p = alpha_p * (logits - delta_p)
        logits_n = alpha_n * (logits - delta_n)

        logits[index] = logits_p[index] * m_hot + logits_n[index] * (1 - m_hot)

        neg_index = torch.where(targets == -1)[0]
        logits[neg_index] = logits_n[neg_index]

        logits.mul_(self.s)

        return logits

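The new margin heads share one contract: they receive precomputed cosine logits plus integer targets (`-1` marks classes not held on the local shard under model parallelism) and inject the margin in place before cross-entropy. A hedged usage sketch; the scale/margin values are typical settings, not defaults taken from this repo:

import torch
import torch.nn.functional as F

num_classes, scale, margin = 751, 64, 0.35           # assumed values for illustration
feat = F.normalize(torch.randn(16, 512))             # L2-normalized embeddings
weight = F.normalize(torch.randn(num_classes, 512))  # L2-normalized class centers
cosine = F.linear(feat, weight)                      # raw cosine logits in [-1, 1]
targets = torch.randint(0, num_classes, (16,))

head = CircleSoftmax(num_classes, scale, margin)     # class defined in the diff above
logits = head(cosine, targets)                       # margin injected in place, then scaled
loss = F.cross_entropy(logits, targets)
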
@@ -1,51 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class ArcSoftmax(nn.Module):
    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN

        self.easy_margin = False

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.threshold = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.xavier_uniform_(self.weight)
        self.register_buffer('t', torch.zeros(1))

    def forward(self, features, targets):
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.threshold, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )

@@ -1,32 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import random

from torch import nn


class BatchDrop(nn.Module):
    """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
    batch drop mask
    """

    def __init__(self, h_ratio, w_ratio):
        super(BatchDrop, self).__init__()
        self.h_ratio = h_ratio
        self.w_ratio = w_ratio

    def forward(self, x):
        if self.training:
            h, w = x.size()[-2:]
            rh = round(self.h_ratio * h)
            rw = round(self.w_ratio * w)
            sx = random.randint(0, h - rh)
            sy = random.randint(0, w - rw)
            mask = x.new_ones(x.size())
            mask[:, :, sx:sx + rh, sy:sy + rw] = 0
            x = x * mask
        return x

@@ -10,11 +10,6 @@ import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex import parallel
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")

__all__ = ["IBN", "get_norm"]


@@ -28,7 +23,7 @@ class BatchNorm(nn.BatchNorm2d):
        self.bias.requires_grad_(not bias_freeze)


class SyncBatchNorm(parallel.SyncBatchNorm):
class SyncBatchNorm(nn.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)

@@ -190,7 +185,7 @@ def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
            or a callable that thakes a channel number and returns
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module
        out_channels: number of channels for normalization layer

@@ -1,45 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class CircleSoftmax(nn.Module):
    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN

        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets):
        sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        s_p = self.s * alpha_p * (sim_mat - delta_p)
        s_n = self.s * alpha_n * (sim_mat - delta_n)

        targets = F.one_hot(targets, num_classes=self._num_classes)

        pred_class_logits = targets * s_p + (1.0 - targets) * s_n

        return pred_class_logits

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )

@@ -1,43 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter


class CosSoftmax(nn.Module):
    r"""Implement of large margin cosine distance:
    Args:
        in_feat: size of each input sample
        num_classes: size of each output sample
    """

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_features = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN
        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, targets):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        targets = F.one_hot(targets, num_classes=self._num_classes)
        output = (targets * phi) + ((1.0 - targets) * cosine)
        output *= self.s

        return output

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )

@@ -0,0 +1,161 @@
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
    block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(self,
                 drop_prob=0.1,
                 block_size=7,
                 gamma_scale=1.0,
                 with_noise=False,
                 inplace=False,
                 batchwise=False,
                 fast=True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

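`DropPath` zeroes entire residual branches per sample with probability `drop_prob` and rescales survivors by `1 / keep_prob`, so activations keep their expectation. Typical placement, a sketch using the `DropPath` defined above (the block itself is invented):

import torch
from torch import nn

class ResidualBlock(nn.Module):
    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.drop_path = DropPath(drop_prob)  # DropPath from the file added above

    def forward(self, x):
        # only the residual branch is dropped; the skip connection always survives
        return x + self.drop_path(self.body(x))

block = ResidualBlock(64).train()
print(block(torch.randn(8, 64)).shape)  # torch.Size([8, 64])
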
@@ -0,0 +1,31 @@
""" Layer/Module Helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

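Worked examples for the two helpers above (values computed by hand, not part of the diff):

to_2tuple(3)                      # -> (3, 3): a scalar is repeated n times
to_2tuple((3, 5))                 # -> (3, 5): an iterable passes through unchanged
make_divisible(37)                # -> 40: nearest multiple of 8, rounding half up
make_divisible(7)                 # -> 8: never below min_value, which defaults to the divisor
make_divisible(44, divisor=32)    # -> 64: rounding down to 32 would lose >10% of 44, so it bumps up one divisor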
@@ -8,20 +8,45 @@ import torch
import torch.nn.functional as F
from torch import nn

-__all__ = ["Flatten",
-           "GeneralizedMeanPooling",
-           "GeneralizedMeanPoolingP",
-           "FastGlobalAvgPool2d",
-           "AdaptiveAvgMaxPool2d",
-           "ClipGlobalAvgPool2d",
-           ]
+__all__ = [
+    'Identity',
+    'Flatten',
+    'GlobalAvgPool',
+    'GlobalMaxPool',
+    'GeneralizedMeanPooling',
+    'GeneralizedMeanPoolingP',
+    'FastGlobalAvgPool',
+    'AdaptiveAvgMaxPool',
+    'ClipGlobalAvgPool',
+]


+class Identity(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(self, input):
+        return input
+
+
class Flatten(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
    def forward(self, input):
        return input.view(input.size(0), -1, 1, 1)


+class GlobalAvgPool(nn.AdaptiveAvgPool2d):
+    def __init__(self, output_size=1, *args, **kwargs):
+        super().__init__(output_size)
+
+
+class GlobalMaxPool(nn.AdaptiveMaxPool2d):
+    def __init__(self, output_size=1, *args, **kwargs):
+        super().__init__(output_size)
+
+
class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
@@ -36,7 +61,7 @@ class GeneralizedMeanPooling(nn.Module):
        be the same as that of the input.
    """

-    def __init__(self, norm=3, output_size=1, eps=1e-6):
+    def __init__(self, norm=3, output_size=(1, 1), eps=1e-6, *args, **kwargs):
        super(GeneralizedMeanPooling, self).__init__()
        assert norm > 0
        self.p = float(norm)
@@ -45,7 +70,7 @@ class GeneralizedMeanPooling(nn.Module):

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
-        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
+        return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
@@ -57,16 +82,16 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """ Same, but norm is trainable
    """

-    def __init__(self, norm=3, output_size=1, eps=1e-6):
+    def __init__(self, norm=3, output_size=(1, 1), eps=1e-6, *args, **kwargs):
        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
        self.p = nn.Parameter(torch.ones(1) * norm)


-class AdaptiveAvgMaxPool2d(nn.Module):
-    def __init__(self):
-        super(AdaptiveAvgMaxPool2d, self).__init__()
-        self.gap = FastGlobalAvgPool2d()
-        self.gmp = nn.AdaptiveMaxPool2d(1)
+class AdaptiveAvgMaxPool(nn.Module):
+    def __init__(self, output_size=1, *args, **kwargs):
+        super().__init__()
+        self.gap = FastGlobalAvgPool()
+        self.gmp = GlobalMaxPool(output_size)

    def forward(self, x):
        avg_feat = self.gap(x)
@@ -75,9 +100,9 @@ class AdaptiveAvgMaxPool2d(nn.Module):
        return feat


-class FastGlobalAvgPool2d(nn.Module):
-    def __init__(self, flatten=False):
-        super(FastGlobalAvgPool2d, self).__init__()
+class FastGlobalAvgPool(nn.Module):
+    def __init__(self, flatten=False, *args, **kwargs):
+        super().__init__()
        self.flatten = flatten

    def forward(self, x):
@@ -88,10 +113,10 @@ class FastGlobalAvgPool2d(nn.Module):
        return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)


-class ClipGlobalAvgPool2d(nn.Module):
-    def __init__(self):
+class ClipGlobalAvgPool(nn.Module):
+    def __init__(self, *args, **kwargs):
        super().__init__()
-        self.avgpool = FastGlobalAvgPool2d()
+        self.avgpool = FastGlobalAvgPool()

    def forward(self, x):
        x = self.avgpool(x)

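A minimal sketch of the GeM pooling introduced above (shapes are illustrative assumptions; only the behavior shown in this diff is relied on):

import torch

gem = GeneralizedMeanPoolingP(norm=3)        # p starts at 3 and is registered as an nn.Parameter
feat = torch.randn(4, 2048, 16, 8)           # hypothetical backbone output, NCHW
pooled = gem(feat)                           # clamp to eps, raise to p, average-pool, take the 1/p root
assert pooled.shape == (4, 2048, 1, 1)
assert any(p.requires_grad for p in gem.parameters())   # unlike GeneralizedMeanPooling, p is trainable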
@@ -0,0 +1,122 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import math
import warnings

import torch
from torch import nn, Tensor


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)


from torch.nn.init import _calculate_fan_in_and_fan_out


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # guard against an invalid mode, which would otherwise leave denom unset
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')

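These initializers are written to be passed to nn.Module.apply, which visits every submodule and dispatches on the class name; a minimal sketch using the functions defined above (the toy model is an assumption for illustration, never run forward here):

import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),   # Conv -> kaiming_normal_, fan_out
    nn.BatchNorm2d(16),                           # BatchNorm -> weight 1.0, bias 0.0
    nn.Linear(16, 10),                            # Linear -> normal_(0, 0.01)
)
model.apply(weights_init_kaiming)
model[-1].apply(weights_init_classifier)          # classifier layers get the tighter std=0.001 init

w = torch.empty(256, 128)
trunc_normal_(w, std=0.02)                        # a and b are absolute cutoffs, so every value lands in [-2, 2]
assert w.min() >= -2. and w.max() <= 2.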
@@ -4,4 +4,20 @@
@contact: sherlockliao01@gmail.com
"""

-from .meta_arch import build_model
+from . import losses
+from .backbones import (
+    BACKBONE_REGISTRY,
+    build_resnet_backbone,
+    build_backbone,
+)
+from .heads import (
+    REID_HEADS_REGISTRY,
+    build_heads,
+    EmbeddingHead,
+)
+from .meta_arch import (
+    build_model,
+    META_ARCH_REGISTRY,
+)
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]