mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
commit
e5ddb3f9ca
22
README.md
22
README.md
@ -36,7 +36,7 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and
|
||||
<tr><td><a href="https://arxiv.org/abs/1911.05722" target="_blank" rel="noopener noreferrer">MoCo</a></td><td>79.18</td><td>60.60</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2003.04297" target="_blank" rel="noopener noreferrer">MoCo v2</a></td><td>84.26</td><td>67.69</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a></td><td>78.95</td><td>61.57</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL (bs4096)</a></td><td>85.10</td><td>69.14</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL (epoch=300)</a></td><td>86.58</td><td>72.35</td></tr>
|
||||
</tbody></table>
|
||||
|
||||
- **Flexibility & Extensibility**
|
||||
@ -68,6 +68,8 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and
|
||||
|
||||
Please refer to [CHANGELOG.md](docs/CHANGELOG.md) for details and release history.
|
||||
|
||||
[2020-10-14] `OpenSelfSup` v0.3.0 is released with some bugs fixed and support of new features.
|
||||
|
||||
[2020-06-26] `OpenSelfSup` v0.2.0 is released with benchmark results and support of new features.
|
||||
|
||||
[2020-06-16] `OpenSelfSup` v0.1.0 is released.
|
||||
@ -88,6 +90,20 @@ Please refer to [MODEL_ZOO.md](docs/MODEL_ZOO.md) for for a comprehensive set of
|
||||
|
||||
This project is released under the [Apache 2.0 license](LICENSE).
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this toolbox in your research, please consider cite:
|
||||
|
||||
```
|
||||
@inproceedings{zhan2020online,
|
||||
title={Online Deep Clustering for Unsupervised Representation Learning},
|
||||
author={Zhan, Xiaohang and Xie, Jiahao and Liu, Ziwei and Ong, Yew-Soon and Loy, Chen Change},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
pages={6688--6697},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
- This repo borrows the architecture design and part of the code from [MMDetection](https://github.com/open-mmlab/mmdetection).
|
||||
@ -96,6 +112,10 @@ This project is released under the [Apache 2.0 license](LICENSE).
|
||||
fair_self_supervision_benchmark](https://github.com/facebookresearch/fair_self_supervision_benchmark).
|
||||
- `openselfsup/third_party/clustering.py` is borrowed from [deepcluster](https://github.com/facebookresearch/deepcluster/blob/master/clustering.py).
|
||||
|
||||
## Contributors
|
||||
|
||||
We encourage researchers interested in Self-Supervised Learning to contribute to OpenSelfSup. Your contributions, including implementing or transferring new methods to OpenSelfSup, performing experiments, reproducing of results, parameter studies, etc, will be recorded in [MODEL_ZOO.md](docs/MODEL_ZOO.md). For now, the contributors include: Xiaohang Zhan ([@XiaohangZhan](http://github.com/XiaohangZhan)), Jiahao Xie ([@Jiahao000](https://github.com/Jiahao000)), Enze Xie ([@xieenze](https://github.com/xieenze)), Zijian He ([@scnuhealthy](https://github.com/scnuhealthy)).
|
||||
|
||||
## Contact
|
||||
|
||||
This repo is currently maintained by Xiaohang Zhan ([@XiaohangZhan](http://github.com/XiaohangZhan)), Jiahao Xie ([@Jiahao000](https://github.com/Jiahao000)) and Enze Xie ([@xieenze](https://github.com/xieenze)).
|
||||
|
108
configs/selfsup/byol/r50_bs256_accumulate16_ep300.py
Normal file
108
configs/selfsup/byol/r50_bs256_accumulate16_ep300.py
Normal file
@ -0,0 +1,108 @@
|
||||
import copy
|
||||
_base_ = '../../base.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='BYOL',
|
||||
pretrained=None,
|
||||
base_momentum=0.99,
|
||||
pre_conv=True,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='SyncBN')),
|
||||
neck=dict(
|
||||
type='NonLinearNeckSimCLR',
|
||||
in_channels=2048,
|
||||
hid_channels=4096,
|
||||
out_channels=256,
|
||||
num_layers=2,
|
||||
sync_bn=True,
|
||||
with_bias=True,
|
||||
with_last_bn=False,
|
||||
with_avg_pool=True),
|
||||
head=dict(type='LatentPredictHead',
|
||||
size_average=True,
|
||||
predictor=dict(type='NonLinearNeckSimCLR',
|
||||
in_channels=256, hid_channels=4096,
|
||||
out_channels=256, num_layers=2, sync_bn=True,
|
||||
with_bias=True, with_last_bn=False, with_avg_pool=False)))
|
||||
# dataset settings
|
||||
data_source_cfg = dict(
|
||||
type='ImageNet',
|
||||
memcached=True,
|
||||
mclient_path='/mnt/lustre/share/memcached_client')
|
||||
data_train_list = 'data/imagenet/meta/train.txt'
|
||||
data_train_root = 'data/imagenet/train'
|
||||
dataset_type = 'BYOLDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224, interpolation=3),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(
|
||||
type='RandomAppliedTrans',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1)
|
||||
],
|
||||
p=0.8),
|
||||
dict(type='RandomGrayscale', p=0.2),
|
||||
dict(
|
||||
type='RandomAppliedTrans',
|
||||
transforms=[
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
sigma_min=0.1,
|
||||
sigma_max=2.0)
|
||||
],
|
||||
p=1.),
|
||||
dict(type='RandomAppliedTrans',
|
||||
transforms=[dict(type='Solarization')], p=0.),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
train_pipeline1 = copy.deepcopy(train_pipeline)
|
||||
train_pipeline2 = copy.deepcopy(train_pipeline)
|
||||
train_pipeline2[4]['p'] = 0.1 # gaussian blur
|
||||
train_pipeline2[5]['p'] = 0.2 # solarization
|
||||
|
||||
data = dict(
|
||||
imgs_per_gpu=32, # total 32*8(gpu)*16(interval)=4096
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_train_list, root=data_train_root,
|
||||
**data_source_cfg),
|
||||
pipeline1=train_pipeline1,
|
||||
pipeline2=train_pipeline2))
|
||||
# additional hooks
|
||||
update_interval = 16 # interval for accumulate gradient
|
||||
custom_hooks = [
|
||||
dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
|
||||
]
|
||||
# optimizer
|
||||
optimizer = dict(type='LARS', lr=4.8, weight_decay=0.000001, momentum=0.9,
|
||||
paramwise_options={
|
||||
'(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay=0., lars_exclude=True),
|
||||
'bias': dict(weight_decay=0., lars_exclude=True)})
|
||||
# apex
|
||||
use_fp16 = False
|
||||
optimizer_config = dict(update_interval=update_interval, use_fp16=use_fp16)
|
||||
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=0.,
|
||||
warmup='linear',
|
||||
warmup_iters=10,
|
||||
warmup_ratio=0.0001, # cannot be 0
|
||||
warmup_by_epoch=True)
|
||||
checkpoint_config = dict(interval=10)
|
||||
# runtime settings
|
||||
total_epochs = 300
|
88
configs/selfsup/deepcluster/r50_withoutsobel.py
Normal file
88
configs/selfsup/deepcluster/r50_withoutsobel.py
Normal file
@ -0,0 +1,88 @@
|
||||
_base_ = '../../base.py'
|
||||
# model settings
|
||||
num_classes = 10000
|
||||
model = dict(
|
||||
type='DeepCluster',
|
||||
pretrained=None,
|
||||
with_sobel=False,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN')),
|
||||
neck=dict(type='AvgPoolNeck'),
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=False, # already has avgpool in the neck
|
||||
in_channels=2048,
|
||||
num_classes=num_classes))
|
||||
# dataset settings
|
||||
data_source_cfg = dict(
|
||||
type='ImageNet',
|
||||
memcached=True,
|
||||
mclient_path='/mnt/lustre/share/memcached_client')
|
||||
data_train_list = 'data/imagenet/meta/train.txt'
|
||||
data_train_root = 'data/imagenet/jpeg/train'
|
||||
dataset_type = 'DeepClusterDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='RandomRotation', degrees=2),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=1.0,
|
||||
hue=0.5),
|
||||
dict(type='RandomGrayscale', p=0.2),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
extract_pipeline = [
|
||||
dict(type='Resize', size=256),
|
||||
dict(type='CenterCrop', size=224),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
data = dict(
|
||||
imgs_per_gpu=64, # 32
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_train_list, root=data_train_root,
|
||||
**data_source_cfg),
|
||||
pipeline=train_pipeline))
|
||||
# additional hooks
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='DeepClusterHook',
|
||||
extractor=dict(
|
||||
imgs_per_gpu=128,
|
||||
workers_per_gpu=8,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_train_list,
|
||||
root=data_train_root,
|
||||
**data_source_cfg),
|
||||
pipeline=extract_pipeline)),
|
||||
clustering=dict(type='Kmeans', k=num_classes, pca_dim=256),
|
||||
unif_sampling=True,
|
||||
reweight=False,
|
||||
reweight_pow=0.5,
|
||||
initial=True, # call initially
|
||||
interval=1)
|
||||
]
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD', lr=0.3, momentum=0.9, weight_decay=0.00001,
|
||||
nesterov=False,
|
||||
paramwise_options={'\Ahead.': dict(momentum=0.)})
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[400])
|
||||
checkpoint_config = dict(interval=10)
|
||||
# runtime settings
|
||||
total_epochs = 200
|
@ -1,5 +1,21 @@
|
||||
## Changelog
|
||||
|
||||
### v0.3.0 (14/10/2020)
|
||||
|
||||
#### Highlight
|
||||
* Support Mixed Precision Training
|
||||
* Improvement of GaussianBlur doubles the training speed
|
||||
* More benchmarking results
|
||||
|
||||
#### Bug Fixes
|
||||
* Fix bugs in moco v2, now the results are reproducible.
|
||||
* Fix bugs in byol.
|
||||
|
||||
#### New Features
|
||||
* Mixed Precision Training
|
||||
* Improvement of GaussianBlur doubles the training speed of MoCo V2, SimCLR, BYOL
|
||||
* More benchmarking results, including Places, VOC, COCO
|
||||
|
||||
### v0.2.0 (26/6/2020)
|
||||
|
||||
#### Highlights
|
||||
|
@ -38,6 +38,12 @@ An example:
|
||||
SRUN_ARGS="-w xx.xx.xx.xx" bash tools/srun_train.sh Dummy configs/selfsup/odc/r50_v1.py 8 --resume_from work_dirs/selfsup/odc/r50_v1/epoch_100.pth
|
||||
```
|
||||
|
||||
### Train with multiple machines
|
||||
|
||||
If you launch with multiple machines simply connected with ethernet, you have to modify `tools/dist_train.sh` or create a new script, please refer to PyTorch [Launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility). Usually it is slow if you do not have high speed networking like InfiniBand.
|
||||
|
||||
If you launch with slurm, the command is the same as that on single machine described above. You only need to change ${GPUS}, e.g., to 16 for two 8-GPU machines.
|
||||
|
||||
### Launch multiple jobs on a single machine
|
||||
|
||||
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
|
||||
|
@ -14,7 +14,7 @@ We have tested the following versions of OS and softwares:
|
||||
|
||||
- OS: Ubuntu 16.04/18.04 and CentOS 7.2
|
||||
- CUDA: 9.0/9.2/10.0/10.1
|
||||
- NCCL: 2.1.15/2.2.13/2.3.7/2.4.2
|
||||
- NCCL: 2.1.15/2.2.13/2.3.7/2.4.2 (PyTorch-1.1 w/ NCCL-2.4.2 has a deadlock bug, see [here](https://github.com/open-mmlab/OpenSelfSup/issues/6))
|
||||
- GCC(G++): 4.9/5.3/5.4/7.3
|
||||
|
||||
### Install openselfsup
|
||||
|
@ -1,24 +1,29 @@
|
||||
# Model Zoo
|
||||
|
||||
**OpenSelfSup needs your contribution!
|
||||
Since we don't have sufficient GPUs to run these large-scale experiments, your contributions, including parameter studies, reproducing of results, implementing new methods, etc, are essential to make OpenSelfSup better. Your contribution will be recorded in the below table, top contributors will be included in the author list of OpenSelfSup!**
|
||||
|
||||
## Pre-trained model download links and speed test.
|
||||
**Note**
|
||||
* The testing GPUs are NVIDIA Tesla V100.
|
||||
* If not specifically indicated, the testing GPUs are NVIDIA Tesla V100.
|
||||
* Experiments with the same batch size are directly comparable in speed.
|
||||
* The table records the implementors who implemented the methods (either by themselves or refactoring from other repos), and the experimenters who performed experiments and reproduced the results. The experimenters should be responsible for the evaluation results on all the benchmarks, and the implementors should be responsible for the implementation as well as the results; If the experimenter is not indicated, an implementator is the experimenter by default.
|
||||
|
||||
<table><thead><tr><th>Method</th><th>Config</th><th>Remarks</th><th>Download link</th><th>Batch size</th><th>Epochs</th><th>Time per epoch</th></tr></thead><tbody>
|
||||
<table><thead><tr><th>Method (Implementator)</th><th>Config (Experimenter)</th><th>Remarks</th><th>Download link</th><th>Batch size</th><th>Epochs</th><th>Time per epoch</th></tr></thead><tbody>
|
||||
<tr><td><a href="https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py" target="_blank" rel="noopener noreferrer">ImageNet</a></td><td>-</td><td>torchvision</td><td><a href="https://drive.google.com/file/d/11xA3TOcbD0qOrwpBfYonEDeseE1wMfBh/view?usp=sharing" target="_blank" rel="noopener noreferrer">imagenet_r50-21352794.pth</a></td><td>-</td><td>-</td><td>-</td></tr>
|
||||
<tr><td>Random</td><td>-</td><td>kaiming</td><td><a href="https://drive.google.com/file/d/1UaFTjd6sbKkZEE-f58Zv30bnx7C1qJBb/view?usp=sharing" target="_blank" rel="noopener noreferrer">random_r50-5d0fa71b.pth</a></td><td>-</td><td>-</td><td>-</td></tr>
|
||||
<tr><td><a href="https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Doersch_Unsupervised_Visual_Representation_ICCV_2015_paper.pdf" target="_blank" rel="noopener noreferrer">Relative-Loc</a></td><td>selfsup/relative_loc/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ibk1BI3PFQxZqcxuDfHs3n7JnWKCgl8x/view?usp=sharing" target="_blank" rel="noopener noreferrer">relative_loc_r50-342c9097.pth</a></td><td>512</td><td>70</td><td>21min17s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1803.07728" target="_blank" rel="noopener noreferrer">Rotation-Pred</a></td><td>selfsup/rotation_pred/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1t3oClmIvQ0p8RZ0V5yvQFltzjqBO823Y/view?usp=sharing" target="_blank" rel="noopener noreferrer">rotation_r50-cfab8ebb.pth</a></td><td>128</td><td>70</td><td>49min58s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1807.05520" target="_blank" rel="noopener noreferrer">DeepCluster</a></td><td>selfsup/deepcluster/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1GxgP7pI18JtFxDIC0hnHOanvUYajoLlg/view?usp=sharing" target="_blank" rel="noopener noreferrer">deepcluster_r50-bb8681e2.pth</a></td><td>512</td><td>200</td><td>41min57s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1805.01978" target="_blank" rel="noopener noreferrer">NPID</a></td><td>selfsup/npid/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1sm6I3Y5XnCWdbmeLSF4YupUtPe5nRQMI/view?usp=sharing" target="_blank" rel="noopener noreferrer">npid_r50-dec3df0c.pth</a></td><td>256</td><td>200</td><td>20min5s</td></tr>
|
||||
<tr><td><a href="https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Doersch_Unsupervised_Visual_Representation_ICCV_2015_paper.pdf" target="_blank" rel="noopener noreferrer">Relative-Loc</a> (<a href="https://github.com/Jiahao000">@Jiahao000</a>)</td><td>selfsup/relative_loc/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ibk1BI3PFQxZqcxuDfHs3n7JnWKCgl8x/view?usp=sharing" target="_blank" rel="noopener noreferrer">relative_loc_r50-342c9097.pth</a></td><td>512</td><td>70</td><td>21min17s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1803.07728" target="_blank" rel="noopener noreferrer">Rotation-Pred</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/rotation_pred/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1t3oClmIvQ0p8RZ0V5yvQFltzjqBO823Y/view?usp=sharing" target="_blank" rel="noopener noreferrer">rotation_r50-cfab8ebb.pth</a></td><td>128</td><td>70</td><td>49min58s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1807.05520" target="_blank" rel="noopener noreferrer">DeepCluster</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/deepcluster/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1GxgP7pI18JtFxDIC0hnHOanvUYajoLlg/view?usp=sharing" target="_blank" rel="noopener noreferrer">deepcluster_r50-bb8681e2.pth</a></td><td>512</td><td>200</td><td>41min57s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1805.01978" target="_blank" rel="noopener noreferrer">NPID</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/npid/r50.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1sm6I3Y5XnCWdbmeLSF4YupUtPe5nRQMI/view?usp=sharing" target="_blank" rel="noopener noreferrer">npid_r50-dec3df0c.pth</a></td><td>256</td><td>200</td><td>20min5s</td></tr>
|
||||
<tr><td></td><td>selfsup/npid/r50_ensure_neg.py</td><td>ensure_neg=True</td><td><a href="https://drive.google.com/file/d/1FldDrb6kzF3CZ7737mwCXVI6HE2aCSaF/view?usp=sharing" target="_blank" rel="noopener noreferrer">npid_r50_ensure_neg-ce09b7ae.pth</a></td><td></td><td></td><td></td></tr>
|
||||
<tr><td><a href="http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf" target="_blank" rel="noopener noreferrer">ODC</a></td><td>selfsup/odc/r50_v1.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1EdhJeZAyMsD_pEW7uMhLzos5xZLdariN/view?usp=sharing" target="_blank" rel="noopener noreferrer">odc_r50_v1-5af5dd0c.pth</a></td><td>512</td><td>440</td><td>28min22s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1911.05722" target="_blank" rel="noopener noreferrer">MoCo</a></td><td>selfsup/moco/r50_v1.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ANXfnoT8yBQQBBqR_kQLQorK20l65KMy/view?usp=sharing" target="_blank" rel="noopener noreferrer">moco_r50_v1-4ad89b5c.pth</a></td><td>256</td><td>200</td><td>22min58s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2003.04297" target="_blank" rel="noopener noreferrer">MoCo v2</a></td><td>selfsup/moco/r50_v2.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ImO8A3uWbrTx21D1IqBDMUQvpN6wmv0d/view?usp=sharing" target="_blank" rel="noopener noreferrer">moco_r50_v2-e3b0c442.pth</a></td><td>256</td><td>200</td><td>55min43s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a></td><td>selfsup/simclr/r50_bs256_ep200.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1aZ43nSdivdNxHbM9DKVoZYVhZ8TNnmPp/view?usp=sharing" target="_blank" rel="noopener noreferrer">simclr_r50_bs256_ep200-4577e9a6.pth</a></td><td>256</td><td>200</td><td>1h1min7s</td></tr>
|
||||
<tr><td><a href="http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf" target="_blank" rel="noopener noreferrer">ODC</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/odc/r50_v1.py (<a href="https://github.com/Jiahao000">@Jiahao000</a>)</td><td>default</td><td><a href="https://drive.google.com/file/d/1EdhJeZAyMsD_pEW7uMhLzos5xZLdariN/view?usp=sharing" target="_blank" rel="noopener noreferrer">odc_r50_v1-5af5dd0c.pth</a></td><td>512</td><td>440</td><td>28min22s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/1911.05722" target="_blank" rel="noopener noreferrer">MoCo</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/moco/r50_v1.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ANXfnoT8yBQQBBqR_kQLQorK20l65KMy/view?usp=sharing" target="_blank" rel="noopener noreferrer">moco_r50_v1-4ad89b5c.pth</a></td><td>256</td><td>200</td><td>22min58s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2003.04297" target="_blank" rel="noopener noreferrer">MoCo v2</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/moco/r50_v2.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1ImO8A3uWbrTx21D1IqBDMUQvpN6wmv0d/view?usp=sharing" target="_blank" rel="noopener noreferrer">moco_r50_v2-e3b0c442.pth</a></td><td>256</td><td>200</td><td>25min7s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/simclr/r50_bs256_ep200.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1aZ43nSdivdNxHbM9DKVoZYVhZ8TNnmPp/view?usp=sharing" target="_blank" rel="noopener noreferrer">simclr_r50_bs256_ep200-4577e9a6.pth</a></td><td>256</td><td>200</td><td>32min13s</td></tr>
|
||||
<tr><td></td><td>selfsup/simclr/r50_bs256_ep200_mocov2_neck.py</td><td>-> MoCo v2 neck</td><td><a href="https://drive.google.com/file/d/1AXpSKqgWfnj6jCgN65BXSTCKFfuIVELa/view?usp=sharing" target="_blank" rel="noopener noreferrer">simclr_r50_bs256_ep200_mocov2_neck-0d6e5ff2.pth</a></td><td></td><td></td><td></td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL</a></td><td>selfsup/byol/r50_bs4096_ep200.py</td><td>default</td><td><a href="https://drive.google.com/file/d/1Whj3j5E3ShQj_VufjrJSzWiq1xcZZCXN/view?usp=sharing" target="_blank" rel="noopener noreferrer">byol_r50-e3b0c442.pth</a></td><td>4096</td><td>200</td><td>14min40s</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL</a> (<a href="https://github.com/XiaohangZhan">@XiaohangZhan</a>)</td><td>selfsup/byol/r50_bs4096_ep200.py (<a href="https://github.com/xieenze">@xieenze</a>)</td><td>default</td><td><a href="https://drive.google.com/file/d/1Whj3j5E3ShQj_VufjrJSzWiq1xcZZCXN/view?usp=sharing" target="_blank" rel="noopener noreferrer">byol_r50-e3b0c442.pth</a></td><td>4096 (256 for speed test)</td><td>200</td><td>30min57s</td></tr>
|
||||
<tr><td></td><td>selfsup/byol/r50_bs256_accumulate16_ep300.py (<a href="https://github.com/scnuhealthy">@scnuhealthy</a>)</td><td>default</td><td><a href="https://drive.google.com/file/d/12Zu9r3fE8qKF4OW6WQXa5Ec6VuA2m3j7/view?usp=sharing" target="_blank" rel="noopener noreferrer">byol_r50_bs256_accmulate16_ep300-5df46722.pth</a></td><td>256</td><td>300</td><td></td></tr>
|
||||
</tbody></table>
|
||||
|
||||
## Benchmarks
|
||||
@ -40,8 +45,10 @@
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a></td><td>selfsup/simclr/r50_bs256_ep200.py</td><td>default</td><td>feat5</td><td>78.95</td><td>32.45</td><td>40.76</td><td>50.4</td><td>59.01</td><td>65.45</td><td>70.13</td><td>73.58</td><td>75.35</td></tr>
|
||||
<tr><td></td><td>selfsup/simclr/r50_bs256_ep200_mocov2_neck.py</td><td>-> MoCo v2 neck</td><td>feat5</td><td>77.65</td><td></td><td></td><td></td><td></td><td></td><td></td><td></td><td></td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL</a></td><td>selfsup/byol/r50_bs4096_ep200.py</td><td>default</td><td>feat5</td><td>85.10</td><td>44.48</td><td>52.09</td><td>62.88</td><td>70.87</td><td>76.18</td><td>79.45</td><td>81.88</td><td>83.08</td></tr>
|
||||
<tr><td></td><td>selfsup/byol/r50_bs256_accumulate16_ep300.py</td><td>default</td><td>feat5</td><td>86.58</td><td></td><td></td><td></td><td></td><td></td><td></td><td></td><td></td></tr>
|
||||
</tbody></table>
|
||||
|
||||
|
||||
### ImageNet Linear Classification
|
||||
|
||||
**Note**
|
||||
@ -49,6 +56,7 @@
|
||||
* For DeepCluster, use the corresponding one with `_sobel`.
|
||||
* ImageNet (Multi) evaluates features in around 9k dimensions from different layers. Top-1 result of the last epoch is reported.
|
||||
* ImageNet (Last) evaluates the last feature after global average pooling, e.g., 2048 dimensions for resnet50. The best top-1 result among all epochs is reported.
|
||||
* Usually, we report the best result from ImageNet (Multi) and ImageNet (Last) to ensure fairness, since different methods achieve their best performance on different layers.
|
||||
|
||||
<table><thead><tr><th rowspan="2">Method</th><th rowspan="2">Config</th><th rowspan="2">Remarks</th><th colspan="5">ImageNet (Multi)</th><th>ImageNet (Last)</th></tr>
|
||||
<tr><td>feat1</td><td>feat2</td><td>feat3</td><td>feat4</td><td>feat5</td><td>avgpool</td></tr></thead><tbody>
|
||||
@ -64,6 +72,7 @@
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a></td><td>selfsup/simclr/r50_bs256_ep200.py</td><td>default</td><td>17.09</td><td>31.37</td><td>41.38</td><td>54.35</td><td>61.57</td><td>60.06</td></tr>
|
||||
<tr><td></td><td>selfsup/simclr/r50_bs256_ep200_mocov2_neck.py</td><td>-> MoCo v2 neck</td><td>16.97</td><td>31.88</td><td>41.73</td><td>54.33</td><td>59.94</td><td>58.00</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL</a></td><td>selfsup/byol/r50_bs4096_ep200.py</td><td>default</td><td>16.70</td><td>34.22</td><td>46.61</td><td>60.78</td><td>69.14</td><td>67.10</td></tr>
|
||||
<tr><td></td><td>selfsup/byol/r50_bs256_accumulate16_ep300.py</td><td>default</td><td>14.07</td><td>34.44</td><td>47.22</td><td>63.08</td><td>72.35</td><td></td></tr>
|
||||
</tbody></table>
|
||||
|
||||
### Places205 Linear Classification
|
||||
@ -125,7 +134,7 @@
|
||||
<tr><td><a href="https://arxiv.org/abs/1911.05722" target="_blank" rel="noopener noreferrer">MoCo</a></td><td>selfsup/moco/r50_v1.py</td><td>default</td><td>r50_lr0_01_head100.py</td><td>60.08</td><td>84.02</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2003.04297" target="_blank" rel="noopener noreferrer">MoCo v2</a></td><td>selfsup/moco/r50_v2.py</td><td>default</td><td>r50_lr0_01_head100.py</td><td>61.80</td><td>85.11</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2002.05709" target="_blank" rel="noopener noreferrer">SimCLR</a></td><td>selfsup/simclr/r50_bs256_ep200.py</td><td>default</td><td>r50_lr0_01_head100.py</td><td>58.46</td><td>82.60</td></tr>
|
||||
<tr><td></td><td>selfsup/simclr/r50_bs256_ep200_mocov2_neck.py</td><td>-> MoCo v2 neck</td><td></td><td>58.38</td><td>82.53</td></tr>
|
||||
<tr><td></td><td>selfsup/simclr/r50_bs256_ep200_mocov2_neck.py</td><td>-> MoCo v2 neck</td><td>r50_lr0_01_head100.py</td><td>58.38</td><td>82.53</td></tr>
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733" target="_blank" rel="noopener noreferrer">BYOL</a></td><td>selfsup/byol/r50_bs4096_ep200.py</td><td>default</td><td>r50_lr0_01_head100.py</td><td>65.94</td><td>87.81</td></tr>
|
||||
</tbody></table>
|
||||
|
||||
@ -169,4 +178,3 @@
|
||||
<tr><td><a href="https://arxiv.org/abs/2006.07733">BYOL</a></td><td>selfsup/byol/r50_bs4096_ep200.py</td><td>default</td><td>60.5</td><td>40.3</td><td>43.9</td><td>56.8</td><td>35.1</td><td>37.3</td></tr>
|
||||
</tbody></table>
|
||||
|
||||
|
||||
|
@ -24,5 +24,5 @@ class ContrastiveDataset(BaseDataset):
|
||||
img_cat = torch.cat((img1.unsqueeze(0), img2.unsqueeze(0)), dim=0)
|
||||
return dict(img=img_cat)
|
||||
|
||||
def evaluate(self, scores, keyword, logger=None):
|
||||
def evaluate(self, scores, keyword, logger=None, **kwargs):
|
||||
raise NotImplemented
|
||||
|
@ -42,7 +42,7 @@ class Cifar100(object):
|
||||
assert split in ['train', 'test']
|
||||
try:
|
||||
self.cifar = CIFAR100(
|
||||
root=root, train=spilt == 'train', download=False)
|
||||
root=root, train=split == 'train', download=False)
|
||||
except:
|
||||
raise Exception("Please download CIFAR10 manually, \
|
||||
in case of downloading the dataset parallelly \
|
||||
|
@ -12,8 +12,10 @@ from openselfsup.utils import build_from_cfg
|
||||
from ..registry import PIPELINES
|
||||
|
||||
# register all existing transforms in torchvision
|
||||
_EXCLUDED_TRANSFORMS = ['GaussianBlur']
|
||||
for m in inspect.getmembers(_transforms, inspect.isclass):
|
||||
PIPELINES.register_module(m[1])
|
||||
if m[0] not in _EXCLUDED_TRANSFORMS:
|
||||
PIPELINES.register_module(m[1])
|
||||
|
||||
|
||||
@PIPELINES.register_module
|
||||
|
@ -1,5 +1,6 @@
|
||||
from math import cos, pi
|
||||
from mmcv.runner import Hook
|
||||
from mmcv.parallel import is_module_wrapper
|
||||
|
||||
from .registry import HOOKS
|
||||
|
||||
@ -17,17 +18,26 @@ class BYOLHook(Hook):
|
||||
for the target network. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, end_momentum=1., **kwargs):
|
||||
def __init__(self, end_momentum=1., update_interval=1, **kwargs):
|
||||
self.end_momentum = end_momentum
|
||||
self.update_interval = update_interval
|
||||
|
||||
def before_train_iter(self, runner):
|
||||
assert hasattr(runner.model.module, 'momentum'), \
|
||||
"The runner must have attribute \"momentum\" in BYOLHook."
|
||||
assert hasattr(runner.model.module, 'base_momentum'), \
|
||||
"The runner must have attribute \"base_momentum\" in BYOLHook."
|
||||
cur_iter = runner.iter
|
||||
max_iter = runner.max_iters
|
||||
base_m = runner.model.module.base_momentum
|
||||
m = self.end_momentum - (self.end_momentum - base_m) * (
|
||||
cos(pi * cur_iter / float(max_iter)) + 1) / 2
|
||||
runner.model.module.momentum = m
|
||||
if self.every_n_iters(runner, self.update_interval):
|
||||
cur_iter = runner.iter
|
||||
max_iter = runner.max_iters
|
||||
base_m = runner.model.module.base_momentum
|
||||
m = self.end_momentum - (self.end_momentum - base_m) * (
|
||||
cos(pi * cur_iter / float(max_iter)) + 1) / 2
|
||||
runner.model.module.momentum = m
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
if self.every_n_iters(runner, self.update_interval):
|
||||
if is_module_wrapper(runner.model):
|
||||
runner.model.module.momentum_update()
|
||||
else:
|
||||
runner.model.momentum_update()
|
||||
|
@ -70,6 +70,10 @@ class BYOL(nn.Module):
|
||||
param_tgt.data = param_tgt.data * self.momentum + \
|
||||
param_ol.data * (1. - self.momentum)
|
||||
|
||||
@torch.no_grad()
|
||||
def momentum_update(self):
|
||||
self._momentum_update()
|
||||
|
||||
def forward_train(self, img, **kwargs):
|
||||
"""Forward computation during training.
|
||||
|
||||
@ -93,7 +97,6 @@ class BYOL(nn.Module):
|
||||
|
||||
loss = self.head(proj_online_v1, proj_target_v2)['loss'] + \
|
||||
self.head(proj_online_v2, proj_target_v1)['loss']
|
||||
self._momentum_update()
|
||||
return dict(loss=loss)
|
||||
|
||||
def forward_test(self, img, **kwargs):
|
||||
|
1
openselfsup/third_party/clustering.py
vendored
1
openselfsup/third_party/clustering.py
vendored
@ -5,6 +5,7 @@ import time
|
||||
import numpy as np
|
||||
import faiss
|
||||
import torch
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
__all__ = ['Kmeans', 'PIC']
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# GENERATED VERSION FILE
|
||||
# TIME: Mon Sep 7 13:47:34 2020
|
||||
# TIME: Wed Oct 21 05:59:41 2020
|
||||
|
||||
__version__ = '0.2.0+d1b12bd'
|
||||
short_version = '0.2.0'
|
||||
__version__ = '0.3.0+df3689d'
|
||||
short_version = '0.3.0'
|
||||
|
Loading…
x
Reference in New Issue
Block a user