From c89239a871db4c9c25925f87b3eba33d5b6e2b91 Mon Sep 17 00:00:00 2001 From: xiaohangzhan Date: Mon, 29 Jun 2020 00:10:34 +0800 Subject: [PATCH] update --- README.md | 20 +-- ...st_test_linear.sh => dist_train_linear.sh} | 11 +- benchmarks/dist_train_semi.sh | 24 ++++ ...un_test_linear.sh => srun_train_linear.sh} | 15 --- benchmarks/srun_train_semi.sh | 31 +++++ .../imagenet/r50_moco.py | 4 +- .../imagenet/r50_moco_sobel.py | 4 +- .../imagenet/r50_multihead.py | 4 +- .../imagenet/r50_multihead_sobel.py | 4 +- .../places205/r50_multihead.py | 2 +- .../places205/r50_multihead_sobel.py | 2 +- .../imagenet_10percent/r50.py | 0 .../imagenet_10percent/r50_sobel.py | 0 .../imagenet_1percent/r50_lr0_01_head1.py | 4 + .../imagenet_1percent/r50_lr0_01_head10.py | 4 + .../imagenet_1percent/r50_lr0_01_head100.py | 4 + .../r50_lr0_01_head1_sobel.py} | 25 +++- .../imagenet_1percent/r50_lr0_1_head1.py | 4 + .../imagenet_1percent/r50_lr0_1_head10.py | 4 + .../imagenet_1percent/r50_lr0_1_head100.py | 4 + .../classification/imagenet_1percent/r50.py | 58 --------- configs/selfsup/byol/r50.py | 97 ++++++++++++++ configs/selfsup/moco/r50_v2_simclr_neck.py | 76 +++++++++++ ...s256_simclr_neck.py => r50_bs256_ep200.py} | 7 +- ...s256.py => r50_bs256_ep200_mocov2_neck.py} | 9 +- .../{r50_bs512.py => r50_bs512_ep200.py} | 10 +- docs/CHANGELOG.md | 19 +++ docs/GETTING_STARTED.md | 34 ++++- docs/MODEL_ZOO.md | 30 ++++- openselfsup/datasets/__init__.py | 1 + openselfsup/datasets/byol.py | 35 +++++ openselfsup/datasets/classification.py | 4 +- openselfsup/datasets/pipelines/transforms.py | 16 +++ openselfsup/hooks/__init__.py | 1 + openselfsup/hooks/byol_hook.py | 29 +++++ openselfsup/hooks/optimizer_hook.py | 16 ++- openselfsup/models/__init__.py | 1 + openselfsup/models/backbones/resnet.py | 2 +- openselfsup/models/byol.py | 84 ++++++++++++ openselfsup/models/heads/__init__.py | 1 + openselfsup/models/heads/latent_pred_head.py | 58 +++++++++ openselfsup/models/necks.py | 123 +++++++++--------- openselfsup/utils/optimizers.py | 78 +++++++---- openselfsup/version.py | 6 +- setup.py | 2 +- tools/dist_test.sh | 17 +++ tools/extract.py | 4 + tools/prepare_data/convert_subset.py | 35 +++++ tools/publish_model.py | 27 ++-- tools/srun_test.sh | 30 +++++ 50 files changed, 839 insertions(+), 241 deletions(-) rename benchmarks/{dist_test_linear.sh => dist_train_linear.sh} (66%) create mode 100644 benchmarks/dist_train_semi.sh rename benchmarks/{srun_test_linear.sh => srun_train_linear.sh} (65%) create mode 100644 benchmarks/srun_train_semi.sh rename configs/{ => benchmarks}/linear_classification/imagenet/r50_moco.py (97%) rename configs/{ => benchmarks}/linear_classification/imagenet/r50_moco_sobel.py (97%) rename configs/{ => benchmarks}/linear_classification/imagenet/r50_multihead.py (97%) rename configs/{ => benchmarks}/linear_classification/imagenet/r50_multihead_sobel.py (97%) rename configs/{ => benchmarks}/linear_classification/places205/r50_multihead.py (99%) rename configs/{ => benchmarks}/linear_classification/places205/r50_multihead_sobel.py (99%) rename configs/{classification => benchmarks/semi_classification}/imagenet_10percent/r50.py (100%) rename configs/{classification => benchmarks/semi_classification}/imagenet_10percent/r50_sobel.py (100%) create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1.py create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head10.py create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head100.py rename configs/{classification/imagenet_1percent/r50_sobel.py => benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1_sobel.py} (76%) create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head1.py create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head10.py create mode 100644 configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head100.py delete mode 100644 configs/classification/imagenet_1percent/r50.py create mode 100644 configs/selfsup/byol/r50.py create mode 100644 configs/selfsup/moco/r50_v2_simclr_neck.py rename configs/selfsup/simclr/{r50_bs256_simclr_neck.py => r50_bs256_ep200.py} (90%) rename configs/selfsup/simclr/{r50_bs256.py => r50_bs256_ep200_mocov2_neck.py} (87%) rename configs/selfsup/simclr/{r50_bs512.py => r50_bs512_ep200.py} (86%) create mode 100644 openselfsup/datasets/byol.py create mode 100644 openselfsup/hooks/byol_hook.py create mode 100644 openselfsup/models/byol.py create mode 100644 openselfsup/models/heads/latent_pred_head.py create mode 100644 tools/dist_test.sh create mode 100644 tools/prepare_data/convert_subset.py create mode 100644 tools/srun_test.sh diff --git a/README.md b/README.md index cf7212c3..a5811bc0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # OpenSelfSup +**News** +OpenSelfSup now supports [BYOL](https://arxiv.org/pdf/2006.07733.pdf)! + ## Introduction The master branch works with **PyTorch 1.1** or higher. @@ -16,19 +19,10 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and ### Major features - **All methods in one repository** - -| | Support | -|-------------------------------------------------------------------------------------------------------------------------------------------------------|:--------:| -| [ImageNet](https://link.springer.com/article/10.1007/s11263-015-0816-y?sa_campaign=email/event/articleAuthor/onlineFirst#) | ✓ | -| [Relative-Loc](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Doersch_Unsupervised_Visual_Representation_ICCV_2015_paper.pdf) | progress | -| [Rotation-Pred](https://arxiv.org/abs/1803.07728) | ✓ | -| [DeepCluster](https://arxiv.org/abs/1807.05520) | ✓ | -| [ODC](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf) | ✓ | -| [NPID](https://arxiv.org/abs/1805.01978) | ✓ | -| [MoCo](https://arxiv.org/abs/1911.05722) | ✓ | -| [MoCo v2](https://arxiv.org/abs/2003.04297) | ✓ | -| [SimCLR](https://arxiv.org/abs/2002.05709) | ✓ | -| [PIRL](http://openaccess.thecvf.com/content_CVPR_2020/papers/Misra_Self-Supervised_Learning_of_Pretext-Invariant_Representations_CVPR_2020_paper.pdf) | progress | + +* For comprehensive comparison in all benchmarks, refer to [MODEL_ZOO.md](docs/MODEL_ZOO.md). + +
MethodVOC07 SVM (best layer)ImageNet (best layer)
ImageNet87.1776.17
Random30.2213.70
Relative-Loc
Rotation-Pred67.3854.99
DeepCluster74.26
NPID74.5056.61
ODC78.4257.6
MoCo79.1860.60
MoCo v284.0566.72
SimCLR78.95
BYOL
- **Flexibility & Extensibility** diff --git a/benchmarks/dist_test_linear.sh b/benchmarks/dist_train_linear.sh similarity index 66% rename from benchmarks/dist_test_linear.sh rename to benchmarks/dist_train_linear.sh index c1a5b373..b1f5de4c 100755 --- a/benchmarks/dist_test_linear.sh +++ b/benchmarks/dist_train_linear.sh @@ -3,9 +3,9 @@ set -e set -x -CFG=$1 # use cfgs under "configs/linear_classification/" +CFG=$1 # use cfgs under "configs/benchmarks/linear_classification/" PRETRAIN=$2 -PY_ARGS=${@:3} +PY_ARGS=${@:3} # --resume_from --deterministic GPUS=1 # in the standard setting, GPUS=1 PORT=${PORT:-29500} @@ -22,10 +22,3 @@ python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ $CFG \ --pretrained $PRETRAIN \ --work_dir $WORK_DIR --seed 0 --launcher="pytorch" ${PY_ARGS} - -# test -python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ - tools/test.py \ - $CFG \ - $WORK_DIR/latest.pth \ - --work_dir $WORK_DIR --launcher="pytorch" diff --git a/benchmarks/dist_train_semi.sh b/benchmarks/dist_train_semi.sh new file mode 100644 index 00000000..b6d7e37b --- /dev/null +++ b/benchmarks/dist_train_semi.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -e +set -x + +CFG=$1 # use cfgs under "configs/benchmarks/semi_classification/imagenet_*percent/" +PRETRAIN=$2 +PY_ARGS=${@:3} +GPUS=4 # in the standard setting, GPUS=4 +PORT=${PORT:-29500} + +if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then + echo "ERROR: Missing arguments." + exit +fi + +WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)" + +# train +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + tools/train.py \ + $CFG \ + --pretrained $PRETRAIN \ + --work_dir $WORK_DIR --seed 0 --launcher="pytorch" ${PY_ARGS} diff --git a/benchmarks/srun_test_linear.sh b/benchmarks/srun_train_linear.sh similarity index 65% rename from benchmarks/srun_test_linear.sh rename to benchmarks/srun_train_linear.sh index c6969037..0ada0276 100644 --- a/benchmarks/srun_test_linear.sh +++ b/benchmarks/srun_train_linear.sh @@ -29,18 +29,3 @@ srun -p ${PARTITION} \ $CFG \ --pretrained $PRETRAIN \ --work_dir $WORK_DIR --seed 0 --launcher="slurm" ${PY_ARGS} - -# test -GLOG_vmodule=MemcachedClient=-1 \ -srun -p ${PARTITION} \ - --job-name=${JOB_NAME} \ - --gres=gpu:${GPUS_PER_NODE} \ - --ntasks=${GPUS} \ - --ntasks-per-node=${GPUS_PER_NODE} \ - --cpus-per-task=${CPUS_PER_TASK} \ - --kill-on-bad-exit=1 \ - ${SRUN_ARGS} \ - python -u tools/test.py \ - $CFG \ - $WORK_DIR/latest.pth \ - --work_dir $WORK_DIR --launcher="slurm" diff --git a/benchmarks/srun_train_semi.sh b/benchmarks/srun_train_semi.sh new file mode 100644 index 00000000..a1fc317c --- /dev/null +++ b/benchmarks/srun_train_semi.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +set -e +set -x + +PARTITION=$1 +CFG=$2 +PRETRAIN=$3 +PY_ARGS=${@:4} +JOB_NAME="openselfsup" +GPUS=8 # in the standard setting, GPUS=8 +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} + +WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)" + +# train +GLOG_vmodule=MemcachedClient=-1 \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/train.py \ + $CFG \ + --pretrained $PRETRAIN \ + --work_dir $WORK_DIR --seed 0 --launcher="slurm" ${PY_ARGS} diff --git a/configs/linear_classification/imagenet/r50_moco.py b/configs/benchmarks/linear_classification/imagenet/r50_moco.py similarity index 97% rename from configs/linear_classification/imagenet/r50_moco.py rename to configs/benchmarks/linear_classification/imagenet/r50_moco.py index 719d8320..2d63e834 100644 --- a/configs/linear_classification/imagenet/r50_moco.py +++ b/configs/benchmarks/linear_classification/imagenet/r50_moco.py @@ -1,4 +1,4 @@ -_base_ = '../../base.py' +_base_ = '../../../base.py' # model settings model = dict( type='Classification', @@ -39,7 +39,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/linear_classification/imagenet/r50_moco_sobel.py b/configs/benchmarks/linear_classification/imagenet/r50_moco_sobel.py similarity index 97% rename from configs/linear_classification/imagenet/r50_moco_sobel.py rename to configs/benchmarks/linear_classification/imagenet/r50_moco_sobel.py index 4b341f16..515e29b4 100644 --- a/configs/linear_classification/imagenet/r50_moco_sobel.py +++ b/configs/benchmarks/linear_classification/imagenet/r50_moco_sobel.py @@ -1,4 +1,4 @@ -_base_ = '../../base.py' +_base_ = '../../../base.py' # model settings model = dict( type='Classification', @@ -39,7 +39,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/linear_classification/imagenet/r50_multihead.py b/configs/benchmarks/linear_classification/imagenet/r50_multihead.py similarity index 97% rename from configs/linear_classification/imagenet/r50_multihead.py rename to configs/benchmarks/linear_classification/imagenet/r50_multihead.py index 8fd98f58..5cbe5f2e 100644 --- a/configs/linear_classification/imagenet/r50_multihead.py +++ b/configs/benchmarks/linear_classification/imagenet/r50_multihead.py @@ -1,4 +1,4 @@ -_base_ = '../../base.py' +_base_ = '../../../base.py' # model settings model = dict( type='Classification', @@ -51,7 +51,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/linear_classification/imagenet/r50_multihead_sobel.py b/configs/benchmarks/linear_classification/imagenet/r50_multihead_sobel.py similarity index 97% rename from configs/linear_classification/imagenet/r50_multihead_sobel.py rename to configs/benchmarks/linear_classification/imagenet/r50_multihead_sobel.py index 237a83e5..2960c50e 100644 --- a/configs/linear_classification/imagenet/r50_multihead_sobel.py +++ b/configs/benchmarks/linear_classification/imagenet/r50_multihead_sobel.py @@ -1,4 +1,4 @@ -_base_ = '../../base.py' +_base_ = '../../../base.py' # model settings model = dict( type='Classification', @@ -51,7 +51,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/linear_classification/places205/r50_multihead.py b/configs/benchmarks/linear_classification/places205/r50_multihead.py similarity index 99% rename from configs/linear_classification/places205/r50_multihead.py rename to configs/benchmarks/linear_classification/places205/r50_multihead.py index f9e6c0f2..60029cf9 100644 --- a/configs/linear_classification/places205/r50_multihead.py +++ b/configs/benchmarks/linear_classification/places205/r50_multihead.py @@ -51,7 +51,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/linear_classification/places205/r50_multihead_sobel.py b/configs/benchmarks/linear_classification/places205/r50_multihead_sobel.py similarity index 99% rename from configs/linear_classification/places205/r50_multihead_sobel.py rename to configs/benchmarks/linear_classification/places205/r50_multihead_sobel.py index dca80e97..7edc9fdd 100644 --- a/configs/linear_classification/places205/r50_multihead_sobel.py +++ b/configs/benchmarks/linear_classification/places205/r50_multihead_sobel.py @@ -51,7 +51,7 @@ test_pipeline = [ ] data = dict( imgs_per_gpu=256, # total 256 - workers_per_gpu=8, + workers_per_gpu=5, train=dict( type=dataset_type, data_source=dict( diff --git a/configs/classification/imagenet_10percent/r50.py b/configs/benchmarks/semi_classification/imagenet_10percent/r50.py similarity index 100% rename from configs/classification/imagenet_10percent/r50.py rename to configs/benchmarks/semi_classification/imagenet_10percent/r50.py diff --git a/configs/classification/imagenet_10percent/r50_sobel.py b/configs/benchmarks/semi_classification/imagenet_10percent/r50_sobel.py similarity index 100% rename from configs/classification/imagenet_10percent/r50_sobel.py rename to configs/benchmarks/semi_classification/imagenet_10percent/r50_sobel.py diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1.py new file mode 100644 index 00000000..dfb6c97f --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=1)}) diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head10.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head10.py new file mode 100644 index 00000000..a8fe6d76 --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head10.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=10)}) diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head100.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head100.py new file mode 100644 index 00000000..12a80442 --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head100.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=100)}) diff --git a/configs/classification/imagenet_1percent/r50_sobel.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1_sobel.py similarity index 76% rename from configs/classification/imagenet_1percent/r50_sobel.py rename to configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1_sobel.py index ecd08758..ed16a61f 100644 --- a/configs/classification/imagenet_1percent/r50_sobel.py +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_01_head1_sobel.py @@ -1,4 +1,4 @@ -_base_ = '../../base.py' +_base_ = '../../../base.py' # model settings model = dict( type='Classification', @@ -37,7 +37,7 @@ test_pipeline = [ dict(type='Normalize', **img_norm_cfg), ] data = dict( - imgs_per_gpu=32, # total 256 + imgs_per_gpu=64, # total 256 workers_per_gpu=2, train=dict( type=dataset_type, @@ -50,11 +50,28 @@ data = dict( data_source=dict( list_file=data_test_list, root=data_test_root, **data_source_cfg), pipeline=test_pipeline)) +# additional hooks +custom_hooks = [ + dict( + type='ValidateHook', + dataset=data['val'], + initial=False, + interval=20, + imgs_per_gpu=32, + workers_per_gpu=2, + eval_param=dict(topk=(1, 5))) +] # optimizer optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, - paramwise_options={'\Ahead.': dict(lr_mult=100)}) + paramwise_options={'\Ahead.': dict(lr_mult=1)}) # learning policy lr_config = dict(policy='step', step=[12, 16], gamma=0.2) -checkpoint_config = dict(interval=2) +checkpoint_config = dict(interval=20) +log_config = dict( + interval=10, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) # runtime settings total_epochs = 20 diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head1.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head1.py new file mode 100644 index 00000000..9c469652 --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head1.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=1)}) diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head10.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head10.py new file mode 100644 index 00000000..97d2be11 --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head10.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=10)}) diff --git a/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head100.py b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head100.py new file mode 100644 index 00000000..94e75883 --- /dev/null +++ b/configs/benchmarks/semi_classification/imagenet_1percent/r50_lr0_1_head100.py @@ -0,0 +1,4 @@ +_base_ = 'base.py' +# optimizer +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, + paramwise_options={'\Ahead.': dict(lr_mult=100)}) diff --git a/configs/classification/imagenet_1percent/r50.py b/configs/classification/imagenet_1percent/r50.py deleted file mode 100644 index c9db1b65..00000000 --- a/configs/classification/imagenet_1percent/r50.py +++ /dev/null @@ -1,58 +0,0 @@ -_base_ = '../../base.py' -# model settings -model = dict( - type='Classification', - pretrained=None, - backbone=dict( - type='ResNet', - depth=50, - out_indices=[4], # 0: conv-1, x: stage-x - norm_cfg=dict(type='SyncBN')), - head=dict( - type='ClsHead', with_avg_pool=True, in_channels=2048, - num_classes=1000)) -# dataset settings -data_source_cfg = dict( - type='ImageNet', - memcached=True, - mclient_path='/mnt/lustre/share/memcached_client') -data_train_list = 'data/imagenet/meta/train_labeled_1percent.txt' -data_train_root = 'data/imagenet/train' -data_test_list = 'data/imagenet/meta/val_labeled.txt' -data_test_root = 'data/imagenet/val' -dataset_type = 'ClassificationDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), -] -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=2, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, root=data_train_root, - **data_source_cfg), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, root=data_test_root, **data_source_cfg), - pipeline=test_pipeline)) -# optimizer -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, - paramwise_options={'\Ahead.': dict(lr_mult=100)}) -# learning policy -lr_config = dict(policy='step', step=[12, 16], gamma=0.2) -checkpoint_config = dict(interval=2) -# runtime settings -total_epochs = 20 diff --git a/configs/selfsup/byol/r50.py b/configs/selfsup/byol/r50.py new file mode 100644 index 00000000..21ae1f2f --- /dev/null +++ b/configs/selfsup/byol/r50.py @@ -0,0 +1,97 @@ +import copy +_base_ = '../../base.py' +# model settings +model = dict( + type='BYOL', + pretrained=None, + base_momentum=0.996, + backbone=dict( + type='ResNet', + depth=50, + in_channels=3, + out_indices=[4], # 0: conv-1, x: stage-x + norm_cfg=dict(type='BN')), + neck=dict( + type='NonLinearNeckV2', + in_channels=2048, + hid_channels=4096, + out_channels=256, + with_avg_pool=True), + head=dict(type='LatentPredictHead', + predictor=dict(type='NonLinearNeckV2', + in_channels=256, hid_channels=4096, + out_channels=256, with_avg_pool=False))) +# dataset settings +data_source_cfg = dict( + type='ImageNet', + memcached=True, + mclient_path='/mnt/lustre/share/memcached_client') +data_train_list = 'data/imagenet/meta/train.txt' +data_train_root = 'data/imagenet/train' +dataset_type = 'BYOLDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=224, interpolation=3), # bicubic + dict(type='RandomHorizontalFlip'), + dict( + type='RandomAppliedTrans', + transforms=[ + dict( + type='ColorJitter', + brightness=0.4, + contrast=0.4, + saturation=0.2, + hue=0.1) + ], + p=0.8), + dict(type='RandomGrayscale', p=0.2), + dict( + type='RandomAppliedTrans', + transforms=[ + dict( + type='GaussianBlur', + sigma_min=0.1, + sigma_max=2.0, + kernel_size=23) + ], + p=1.), + dict(type='RandomAppliedTrans', + transforms=[dict(type='Solarization')], p=0.), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), +] +train_pipeline1 = copy.deepcopy(train_pipeline) +train_pipeline2 = copy.deepcopy(train_pipeline) +train_pipeline2[4]['p'] = 0.1 # gaussian blur +train_pipeline2[5]['p'] = 0.2 # solarization + +data = dict( + imgs_per_gpu=32, # total 32*8=256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, root=data_train_root, + **data_source_cfg), + pipeline1=train_pipeline1, + pipeline2=train_pipeline2)) +# additional hooks +custom_hooks = [ + dict(type='BYOLHook', end_momentum=1.) +] +# optimizer +optimizer = dict(type='LARS', lr=0.2, weight_decay=0.0000015, momentum=0.9, + paramwise_options={ + '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay=0., lars_exclude=True), + 'bias': dict(weight_decay=0., lars_exclude=True)}) +# learning policy +lr_config = dict( + policy='CosineAnealing', + min_lr=0., + warmup='linear', + warmup_iters=2, + warmup_ratio=0.0001, # cannot be 0 + warmup_by_epoch=True) +checkpoint_config = dict(interval=10) +# runtime settings +total_epochs = 200 diff --git a/configs/selfsup/moco/r50_v2_simclr_neck.py b/configs/selfsup/moco/r50_v2_simclr_neck.py new file mode 100644 index 00000000..2cfc3c19 --- /dev/null +++ b/configs/selfsup/moco/r50_v2_simclr_neck.py @@ -0,0 +1,76 @@ +_base_ = '../../base.py' +# model settings +model = dict( + type='MOCO', + pretrained=None, + queue_len=65536, + feat_dim=128, + momentum=0.999, + backbone=dict( + type='ResNet', + depth=50, + in_channels=3, + out_indices=[4], # 0: conv-1, x: stage-x + norm_cfg=dict(type='BN')), + neck=dict( + type='NonLinearNeckSimCLR', # SimCLR non-linear neck + in_channels=2048, + hid_channels=2048, + out_channels=128, + num_layers=2, + with_avg_pool=True), + head=dict(type='ContrastiveHead', temperature=0.2)) +# dataset settings +data_source_cfg = dict( + type='ImageNet', + memcached=True, + mclient_path='/mnt/lustre/share/memcached_client') +data_train_list = 'data/imagenet/meta/train.txt' +data_train_root = 'data/imagenet/train' +dataset_type = 'ContrastiveDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)), + dict( + type='RandomAppliedTrans', + transforms=[ + dict( + type='ColorJitter', + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.4) + ], + p=0.8), + dict(type='RandomGrayscale', p=0.2), + dict( + type='RandomAppliedTrans', + transforms=[ + dict( + type='GaussianBlur', + sigma_min=0.1, + sigma_max=2.0, + kernel_size=23) + ], + p=0.5), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), +] +data = dict( + imgs_per_gpu=32, # total 32*8=256 + workers_per_gpu=4, + drop_last=True, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, root=data_train_root, + **data_source_cfg), + pipeline=train_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.03, weight_decay=0.0001, momentum=0.9) +# learning policy +lr_config = dict(policy='CosineAnealing', min_lr=0.) +checkpoint_config = dict(interval=10) +# runtime settings +total_epochs = 200 diff --git a/configs/selfsup/simclr/r50_bs256_simclr_neck.py b/configs/selfsup/simclr/r50_bs256_ep200.py similarity index 90% rename from configs/selfsup/simclr/r50_bs256_simclr_neck.py rename to configs/selfsup/simclr/r50_bs256_ep200.py index 9766bdf2..f2318c97 100644 --- a/configs/selfsup/simclr/r50_bs256_simclr_neck.py +++ b/configs/selfsup/simclr/r50_bs256_ep200.py @@ -64,14 +64,17 @@ data = dict( **data_source_cfg), pipeline=train_pipeline)) # optimizer -optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9) +optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9, + paramwise_options={ + '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay=0., lars_exclude=True), + 'bias': dict(weight_decay=0., lars_exclude=True)}) # learning policy lr_config = dict( policy='CosineAnealing', min_lr=0., warmup='linear', warmup_iters=10, - warmup_ratio=0.01, + warmup_ratio=0.0001, warmup_by_epoch=True) checkpoint_config = dict(interval=10) # runtime settings diff --git a/configs/selfsup/simclr/r50_bs256.py b/configs/selfsup/simclr/r50_bs256_ep200_mocov2_neck.py similarity index 87% rename from configs/selfsup/simclr/r50_bs256.py rename to configs/selfsup/simclr/r50_bs256_ep200_mocov2_neck.py index 989e9abb..67cacd89 100644 --- a/configs/selfsup/simclr/r50_bs256.py +++ b/configs/selfsup/simclr/r50_bs256_ep200_mocov2_neck.py @@ -10,7 +10,7 @@ model = dict( out_indices=[4], # 0: conv-1, x: stage-x norm_cfg=dict(type='SyncBN')), neck=dict( - type='NonLinearNeckV1', # simple fc-relu-fc neck + type='NonLinearNeckV1', # simple fc-relu-fc neck in MoCo v2 in_channels=2048, hid_channels=2048, out_channels=128, @@ -63,14 +63,17 @@ data = dict( **data_source_cfg), pipeline=train_pipeline)) # optimizer -optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9) +optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9, + paramwise_options={ + '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay=0., lars_exclude=True), + 'bias': dict(weight_decay=0., lars_exclude=True)}) # learning policy lr_config = dict( policy='CosineAnealing', min_lr=0., warmup='linear', warmup_iters=10, - warmup_ratio=0.01, + warmup_ratio=0.0001, warmup_by_epoch=True) checkpoint_config = dict(interval=10) # runtime settings diff --git a/configs/selfsup/simclr/r50_bs512.py b/configs/selfsup/simclr/r50_bs512_ep200.py similarity index 86% rename from configs/selfsup/simclr/r50_bs512.py rename to configs/selfsup/simclr/r50_bs512_ep200.py index 8d8749b0..100dd9d2 100644 --- a/configs/selfsup/simclr/r50_bs512.py +++ b/configs/selfsup/simclr/r50_bs512_ep200.py @@ -10,10 +10,11 @@ model = dict( out_indices=[4], # 0: conv-1, x: stage-x norm_cfg=dict(type='SyncBN')), neck=dict( - type='NonLinearNeckV1', + type='NonLinearNeckSimCLR', # SimCLR non-linear neck in_channels=2048, hid_channels=2048, out_channels=128, + num_layers=2, with_avg_pool=True), head=dict(type='ContrastiveHead', temperature=0.1)) # dataset settings @@ -63,14 +64,17 @@ data = dict( **data_source_cfg), pipeline=train_pipeline)) # optimizer -optimizer = dict(type='LARS', lr=0.6, weight_decay=0.000001, momentum=0.9) +optimizer = dict(type='LARS', lr=0.6, weight_decay=0.000001, momentum=0.9, + paramwise_options={ + '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay=0., lars_exclude=True), + 'bias': dict(weight_decay=0., lars_exclude=True)}) # learning policy lr_config = dict( policy='CosineAnealing', min_lr=0., warmup='linear', warmup_iters=10, - warmup_ratio=0.01, + warmup_ratio=0.0001, warmup_by_epoch=True) checkpoint_config = dict(interval=10) # runtime settings diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 76877891..97b2f2fc 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,2 +1,21 @@ ## Changelog +### v0.2.0 (26/6/2020) + +#### Highlights +* Support BYOL +* Support semi-supervised benchmarks + +#### Bug Fixes +* Fix hash id in publish_model.py + +#### New Features + +* Support BYOL. +* Separate train and test scripts in linear/semi evaluation. +* Support semi-supevised benchmarks: benchmarks/dist_train_semi.sh. +* Move benchmarks related configs into configs/benchmarks/. +* Provide benchmarking results and model download links. +* Support updating network every several interations. +* Support LARS optimizer with nesterov. +* Support excluding specific parameters from LARS adaptation and weight decay required in SimCLR and BYOL. diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 78d70d5e..0248e97e 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -5,7 +5,7 @@ For installation instructions, please see [INSTALL.md](INSTALL.md). ## Train existing methods -**Note**: The default learning rate in config files is for 8 GPUs (except for those under `configs/linear_classification` that use 1 GPU). If using differnt number GPUs, the total batch size will change in proportion, you have to scale the learning rate following `new_lr = old_lr * new_ngpus / old_ngpus`. We recommend to use `tools/dist_train.sh` even with 1 gpu, since some methods do not support non-distributed training. +**Note**: The default learning rate in config files is for 8 GPUs (except for those under `configs/benchmarks/linear_classification` that use 1 GPU). If using differnt number GPUs, the total batch size will change in proportion, you have to scale the learning rate following `new_lr = old_lr * new_ngpus / old_ngpus`. We recommend to use `tools/dist_train.sh` even with 1 gpu, since some methods do not support non-distributed training. ### Train with single/multiple GPUs @@ -55,6 +55,14 @@ GPUS_PER_NODE=4 bash tools/srun_train.sh ${PARTITION} ${CONFIG_FILE} 4 --port 29 GPUS_PER_NODE=4 bash tools/srun_train.sh ${PARTITION} ${CONFIG_FILE} 4 --port 29501 ``` +### What if I do not have so many GPUs? + +Assuming that you only have 1 GPU that can contain 64 images in a batch, while you expect the batch size to be 256, you may add the following line into your config file. It performs network update every 4 iterations. In this way, the equivalent batch size is 256. Of course, it is about 4x slower than using 4 GPUs. Note that the workaround is not applicable for methods like SimCLR which require intra-batch communication. + +```python +optimizer_config = dict(update_interval=4) +``` + ## Benchmarks We provide several standard benchmarks to evaluate representation learning. The config files or scripts for evaluation mentioned below are NOT recommended to be changed if you want to use this repo in your publications. We hope that all methods are under a fair comparison. @@ -70,6 +78,7 @@ bash benchmarks/dist_test_svm_pretrain.sh ${CONFIG_FILE} ${PRETRAIN} ${FEAT_LIST bash benchmarks/dist_test_svm_pretrain.sh ${CONFIG_FILE} "random" ${FEAT_LIST} ${GPUS} ``` Augments: +- `${CONFIG_FILE}` the config file of the self-supervised experiment. - `${FEAT_LIST}` is a string to specify features from layer1 to layer5 to evaluate; e.g., if you want to evaluate layer5 only, then `FEAT_LIST` is `"feat5"`, if you want to evaluate all features, then then `FEAT_LIST` is `"feat1 feat2 feat3 feat4 feat5"` (separated by space). If left empty, the default `FEAT_LIST` is `"feat5"`. - `$GPUS` is the number of GPUs to extract features. @@ -91,20 +100,33 @@ Arguments: **Next**, train and test linear classification: ```shell -bash benchmarks/dist_test_linear.sh ${CONFIG_FILE} ${WEIGHT_FILE} [optional arguments] +# train +bash benchmarks/dist_train_linear.sh ${CONFIG_FILE} ${WEIGHT_FILE} [optional arguments] +# test (unnecessary if have validation in training) +bash tools/dist_test.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} ``` Augments: -- `CONFIG_FILE`: Use config files under "configs/linear_classification/". Note that if you want to test DeepCluster that has a sobel layer before the backbone, you have to use the config file named `*_sobel.py`, e.g., `configs/linear_classification/imagenet/r50_multihead_sobel.py`. +- `CONFIG_FILE`: Use config files under "configs/benchmarks/linear_classification/". Note that if you want to test DeepCluster that has a sobel layer before the backbone, you have to use the config file named `*_sobel.py`, e.g., `configs/benchmarks/linear_classification/imagenet/r50_multihead_sobel.py`. - Optional arguments include: - `--resume_from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file. - `--deterministic`: Switch on "deterministic" mode which slows down training but the results are reproducible. Working directories: -Where are the checkpoints and logs? E.g., if you use `configs/linear_classification/imagenet/r50_multihead.py` to evaluate `pretrains/moco_v1_epoch200.pth`, then the working directories for this evalution is `work_dirs/linear_classification/imagenet/r50_multihead/moco_v1_epoch200.pth/`. +Where are the checkpoints and logs? E.g., if you use `configs/benchmarks/linear_classification/imagenet/r50_multihead.py` to evaluate `pretrains/moco_v1_epoch200.pth`, then the working directories for this evalution is `work_dirs/benchmarks/linear_classification/imagenet/r50_multihead/moco_v1_epoch200.pth/`. ### ImageNet Semi-Supervised Classification -Coming soon +```shell +# train +bash benchmarks/dist_train_linear.sh ${CONFIG_FILE} ${WEIGHT_FILE} [optional arguments] +# test (unnecessary if have validation in training) +bash tools/dist_test.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} +``` +Augments: +- `CONFIG_FILE`: Use config files under "configs/benchmarks/semi_classification/". Note that if you want to test DeepCluster that has a sobel layer before the backbone, you have to use the config file named `*_sobel.py`, e.g., `configs/benchmarks/semi_classification/imagenet_1percent/r50_sobel.py`. +- Optional arguments include: + - `--resume_from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file. + - `--deterministic`: Switch on "deterministic" mode which slows down training but the results are reproducible. ### VOC07+12 / COCO17 Object Detection @@ -206,7 +228,7 @@ train_pipeline = [ * Parameter-wise optimization parameters. -You may specify optimization paramters including lr, momentum and weight_decay for a certain group of paramters in the config file with `paramwise_options`. `paramwise_options` is a dict whose key is regular expressions and value is options. Options include 6 fields: lr, lr_mult, momentum, momentum_mult, weight_decay, weight_decay_mult. +You may specify optimization paramters including lr, momentum and weight_decay for a certain group of paramters in the config file with `paramwise_options`. `paramwise_options` is a dict whose key is regular expressions and value is options. Options include 6 fields: lr, lr_mult, momentum, momentum_mult, weight_decay, weight_decay_mult, lars_exclude (only works with LARS optimizer). ```python # this config sets all normalization layers with weight_decay_mult=0.1, diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index 5e2df16c..f47de241 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -1,6 +1,30 @@ -#Model Zoo +# Model Zoo + +## Pre-trained model download links + +
MethodConfigRemarksDownload link
ImageNet-torchvisionimagenet_r50-21352794.pth
Random-kaimingrandom_r50-5d0fa71b.pth
Relative-Locselfsup/relative_loc/r50.pydefault
Rotation-Predselfsup/rotation_pred/r50.pydefaultrotation_r50-cfab8ebb.pth
DeepClusterselfsup/deepcluster/r50.pydefaultdeepcluster_r50-bb8681e2.pth
NPIDselfsup/npid/r50.pydefaultnpid_r50-dec3df0c.pth
ODCselfsup/odc/r50_v1.pydefaultodc_r50_v1-5af5dd0c.pth
MoCoselfsup/moco/r50_v1.pydefaultmoco_r50_v1-4ad89b5c.pth
MoCo v2selfsup/moco/r50_v2.pydefaultmoco_r50_v2-58f10cfe.pth
selfsup/moco/r50_v2.py-> SimCLR neck
moco_r50_v2_simclr_neck-70379356.pth
SimCLRselfsup/simclr/r50_bs256_ep200.pydefaultsimclr_r50_bs256_ep200-4577e9a6.pth
selfsup/simclr/r50_bs256_ep200_mocov2_neck.py-> MoCo v2 necksimclr_r50_bs256_ep200_mocov2_neck-0d6e5ff2.pth
BYOLselfsup/byol/r50.pydefault
-## Evaluation +## Benchmarks -
ConfigRemarksVOC07 SVMImageNet (Multi)ImageNet (Last)
feat5 by defaultfeat1feat2feat3feat4feat5avgpool
ImageNet-torchvision87.2
Random-kaiming30.2 (feat2)
Relative-Loc
Rotation-Predselfsup/rotation_pred/r50.pydefault67.4 (feat4)
DeepClusterselfsup/deepcluster/r50.pydefault74.3
NPIDselfsup/npid/r50.pydefault74.5
ODCselfsup/odc/r50_v1.pydefault78.214.831.642.555.757.6
MoCoselfsup/moco/r50_v1.pydefault79.215.3233.0844.6857.2760.6
MoCo v2selfsup/moco/r50_v2.pydefault84.115.3534.5745.8160.9666.72
selfsup/moco/r50_v2.pyMoCo_v2 neck -> SimCLR neck
SimCLRselfsup/simclr/r50_bs256.pydefault
selfsup/simclr/r50_bs256_simple_neck.pySimCLR neck -> MoCo_v2 neck77.65
+### VOC07 SVM & SVM Low-shot + +
MethodConfigRemarksBest layerVOC07 SVMImageNet (Multi)
124816326496
ImageNet-torchvisionfeat587.1752.9963.5573.778.7981.7683.7585.1885.97
Random-kaimingfeat230.22
Relative-Locfeat5
Rotation-Predselfsup/rotation_pred/r50.pydefaultfeat467.38
DeepClusterselfsup/deepcluster/r50.pydefaultfeat574.26
NPIDselfsup/npid/r50.pydefaultfeat574.50
ODCselfsup/odc/r50_v1.pydefaultfeat578.42
MoCoselfsup/moco/r50_v1.pydefaultfeat579.18
MoCo v2selfsup/moco/r50_v2.pydefaultfeat584.05
selfsup/moco/r50_v2_simclr_neck.py-> SimCLR neck
feat584.00
SimCLRselfsup/simclr/r50_bs256_ep200.pydefaultfeat578.95
selfsup/simclr/r50_bs256_ep200_mocov2_neck.py-> MoCo v2 neckfeat577.65
BYOLselfsup/byol/r50.pydefault
+ +### ImageNet Linear Classification + +
MethodConfigRemarksImageNet (Multi)ImageNet (Last)
feat1feat2feat3feat4feat5avgpool
ImageNet-torchvision15.1833.9647.8667.5676.1774.12
Random-kaiming4.35
Relative-Locselfsup/relative_loc/r50.pydefault
Rotation-Predselfsup/rotation_pred/r50.pydefault12.8934.3044.9154.9949.09
DeepClusterselfsup/deepcluster/r50.pydefault46.92
NPIDselfsup/npid/r50.pydefault14.2831.2040.6854.4656.61
ODCselfsup/odc/r50_v1.pydefault14.7531.5542.4955.7257.57
MoCoselfsup/moco/r50_v1.pydefault15.3233.0844.6857.2760.6061.02
MoCo v2selfsup/moco/r50_v2.pydefault15.3534.5745.8160.9666.7267.02
selfsup/moco/r50_v2.py-> SimCLR neck
SimCLRselfsup/simclr/r50_bs256.pydefault
BYOLselfsup/byol/r50.pydefault
+ +### Place Linear Classification + + +### ImageNet Semi-Supervised Classification + +**Note** +* In this benchmark, the necks or heads are removed and only the backbone CNN is evaluated by appending a linear classification head. All parameters are fine-tuned. +* When training with 1% ImageNet, we find hyper-parameters especially the learning rate greatly influence the performance. Hence, we prepare a list of settings with the base learning rate from \{0.01, 0.1\} and the learning rate multiplier for the head from \{1, 10, 100\}. We choose the best performing setting for each method. +* Please use `--deterministic` in this benchmark. + +
MethodConfigRemarksOptimal setting for ImageNet 1%ImageNet 1%
top-1top-5
ImageNet-torchvisionr50_lr0_01_head1.py63.1085.73
Random-kaimingr50_lr0_01_head1.py1.564.99
Relative-Locselfsup/relative_loc/r50.pydefault
Rotation-Predselfsup/rotation_pred/r50.pydefaultr50_lr0_01_head100.py18.9844.05
DeepClusterselfsup/deepcluster/r50.pydefaultr50_lr0_01_head1_sobel.py33.4458.62
NPIDselfsup/npid/r50.pydefaultr50_lr0_01_head100.py27.9554.37
ODCselfsup/odc/r50_v1.pydefaultr50_lr0_1_head100.py32.3961.02
MoCoselfsup/moco/r50_v1.pydefaultr50_lr0_01_head100.py33.1561.30
MoCo v2selfsup/moco/r50_v2.pydefaultr50_lr0_01_head100.py38.7167.90
selfsup/moco/r50_v2.py-> SimCLR neck
SimCLRselfsup/simclr/r50_bs256_ep200.pydefaultr50_lr0_01_head100.py36.0964.50
selfsup/simclr/r50_bs256_ep200_mocov2_neck.py-> MoCo v2 neck
BYOLselfsup/byol/r50.pydefault
+ +### PASCAL VOC07+12 Object Detection diff --git a/openselfsup/datasets/__init__.py b/openselfsup/datasets/__init__.py index 11e81bbe..d97a226b 100644 --- a/openselfsup/datasets/__init__.py +++ b/openselfsup/datasets/__init__.py @@ -1,4 +1,5 @@ from .builder import build_dataset +from .byol import BYOLDataset from .data_sources import * from .pipelines import * from .classification import ClassificationDataset diff --git a/openselfsup/datasets/byol.py b/openselfsup/datasets/byol.py new file mode 100644 index 00000000..abe90a63 --- /dev/null +++ b/openselfsup/datasets/byol.py @@ -0,0 +1,35 @@ +import torch +from torch.utils.data import Dataset + +from openselfsup.utils import build_from_cfg + +from torchvision.transforms import Compose + +from .registry import DATASETS, PIPELINES +from .builder import build_datasource + + +@DATASETS.register_module +class BYOLDataset(Dataset): + """Dataset for BYOL. + """ + + def __init__(self, data_source, pipeline1, pipeline2): + self.data_source = build_datasource(data_source) + pipeline1 = [build_from_cfg(p, PIPELINES) for p in pipeline1] + self.pipeline1 = Compose(pipeline1) + pipeline2 = [build_from_cfg(p, PIPELINES) for p in pipeline2] + self.pipeline2 = Compose(pipeline2) + + def __len__(self): + return self.data_source.get_length() + + def __getitem__(self, idx): + img = self.data_source.get_sample(idx) + img1 = self.pipeline1(img) + img2 = self.pipeline2(img) + img_cat = torch.cat((img1.unsqueeze(0), img2.unsqueeze(0)), dim=0) + return dict(img=img_cat) + + def evaluate(self, scores, keyword, logger=None, **kwargs): + raise NotImplemented diff --git a/openselfsup/datasets/classification.py b/openselfsup/datasets/classification.py index 584cd952..7e222d8a 100644 --- a/openselfsup/datasets/classification.py +++ b/openselfsup/datasets/classification.py @@ -35,9 +35,9 @@ class ClassificationDataset(BaseDataset): for k in topk: correct_k = correct[:k].view(-1).float().sum(0).item() acc = correct_k * 100.0 / num - eval_res["{}_acc@{}".format(keyword, k)] = acc + eval_res["{}_top{}".format(keyword, k)] = acc if logger is not None and logger != 'silent': print_log( - "{}_acc@{}: {:.03f}".format(keyword, k, acc), + "{}_top{}: {:.03f}".format(keyword, k, acc), logger=logger) return eval_res diff --git a/openselfsup/datasets/pipelines/transforms.py b/openselfsup/datasets/pipelines/transforms.py index c5715688..46575dd8 100644 --- a/openselfsup/datasets/pipelines/transforms.py +++ b/openselfsup/datasets/pipelines/transforms.py @@ -90,3 +90,19 @@ class GaussianBlur(object): def __repr__(self): repr_str = self.__class__.__name__ return repr_str + + +@PIPELINES.register_module +class Solarization(object): + + def __init__(self, threshold=128): + self.threshold = threshold + + def __call__(self, img): + img = np.array(img) + img = np.where(img < self.threshold, img, 255 -img) + return Image.fromarray(img.astype(np.uint8)) + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/openselfsup/hooks/__init__.py b/openselfsup/hooks/__init__.py index cdcd6cde..584b3b29 100644 --- a/openselfsup/hooks/__init__.py +++ b/openselfsup/hooks/__init__.py @@ -1,4 +1,5 @@ from .builder import build_hook +from .byol_hook import BYOLHook from .deepcluster_hook import DeepClusterHook from .odc_hook import ODCHook from .optimizer_hook import DistOptimizerHook diff --git a/openselfsup/hooks/byol_hook.py b/openselfsup/hooks/byol_hook.py new file mode 100644 index 00000000..38da5436 --- /dev/null +++ b/openselfsup/hooks/byol_hook.py @@ -0,0 +1,29 @@ +from math import cos, pi +from mmcv.runner import Hook + +from .registry import HOOKS + + +@HOOKS.register_module +class BYOLHook(Hook): + '''Hook in BYOL + + This hook including momentum adjustment in BYOL following: + m = 1 - ( 1- m_0) * (cos(pi * k / K) + 1) / 2, + k: current step, K: total steps. + ''' + + def __init__(self, end_momentum=1., **kwargs): + self.end_momentum = end_momentum + + def before_train_iter(self, runner): + assert hasattr(runner.model.module, 'momentum'), \ + "The runner must have attribute \"momentum\" in BYOLHook." + assert hasattr(runner.model.module, 'base_momentum'), \ + "The runner must have attribute \"base_momentum\" in BYOLHook." + cur_iter = runner.iter + max_iter = runner.max_iters + base_m = runner.model.module.base_momentum + m = self.end_momentum - (self.end_momentum - base_m) * ( + cos(pi * cur_iter / float(max_iter)) + 1) / 2 + runner.model.module.momentum = m diff --git a/openselfsup/hooks/optimizer_hook.py b/openselfsup/hooks/optimizer_hook.py index e8c1b7c9..f7c38775 100644 --- a/openselfsup/hooks/optimizer_hook.py +++ b/openselfsup/hooks/optimizer_hook.py @@ -3,14 +3,20 @@ from mmcv.runner import OptimizerHook class DistOptimizerHook(OptimizerHook): - def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): + def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1): self.grad_clip = grad_clip self.coalesce = coalesce self.bucket_size_mb = bucket_size_mb + self.update_interval = update_interval + + def before_run(self, runner): + runner.optimizer.zero_grad() def after_train_iter(self, runner): - runner.optimizer.zero_grad() + runner.outputs['loss'] /= self.update_interval runner.outputs['loss'].backward() - if self.grad_clip is not None: - self.clip_grads(runner.model.parameters()) - runner.optimizer.step() + if self.every_n_iters(runner, self.update_interval): + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() + runner.optimizer.zero_grad() diff --git a/openselfsup/models/__init__.py b/openselfsup/models/__init__.py index 94f94385..2ca81628 100644 --- a/openselfsup/models/__init__.py +++ b/openselfsup/models/__init__.py @@ -1,5 +1,6 @@ from .backbones import * # noqa: F401,F403 from .builder import (build_backbone, build_model, build_head, build_loss) +from .byol import BYOL from .heads import * from .classification import Classification from .deepcluster import DeepCluster diff --git a/openselfsup/models/backbones/resnet.py b/openselfsup/models/backbones/resnet.py index db2c5fe0..f76ee975 100644 --- a/openselfsup/models/backbones/resnet.py +++ b/openselfsup/models/backbones/resnet.py @@ -386,7 +386,7 @@ class ResNet(nn.Module): def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) + load_checkpoint(self, pretrained, strict=True, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): diff --git a/openselfsup/models/byol.py b/openselfsup/models/byol.py new file mode 100644 index 00000000..f06d97cd --- /dev/null +++ b/openselfsup/models/byol.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn + +from openselfsup.utils import print_log + +from . import builder +from .registry import MODELS + + +@MODELS.register_module +class BYOL(nn.Module): + '''BYOL unofficial implementation. Paper: https://arxiv.org/abs/2006.07733 + ''' + + def __init__(self, + backbone, + neck=None, + head=None, + pretrained=None, + base_momentum=0.996, + **kwargs): + super(BYOL, self).__init__() + self.online_net = nn.Sequential( + builder.build_backbone(backbone), builder.build_neck(neck)) + self.target_net = nn.Sequential( + builder.build_backbone(backbone), builder.build_neck(neck)) + self.backbone = self.online_net[0] + for param in self.target_net.parameters(): + param.requires_grad = False + self.head = builder.build_head(head) + self.init_weights(pretrained=pretrained) + + self.base_momentum = base_momentum + self.momentum = base_momentum + + def init_weights(self, pretrained=None): + if pretrained is not None: + print_log('load model from: {}'.format(pretrained), logger='root') + self.online_net[0].init_weights(pretrained=pretrained) # backbone + self.online_net[1].init_weights(init_linear='kaiming') # projection + for param_ol, param_tgt in zip(self.online_net.parameters(), + self.target_net.parameters()): + param_tgt.data.copy_(param_ol.data) + # init the predictor in the head + self.head.init_weights() + + @torch.no_grad() + def _momentum_update(self): + """ + Momentum update of the target network. + """ + for param_ol, param_tgt in zip(self.online_net.parameters(), + self.target_net.parameters()): + param_tgt.data = param_tgt.data * self.momentum + \ + param_ol.data * (1. - self.momentum) + + def forward_train(self, img, **kwargs): + assert img.dim() == 5, \ + "Input must have 5 dims, got: {}".format(img.dim()) + img_v1 = img[:, 0, ...].contiguous() + img_v2 = img[:, 1, ...].contiguous() + img_cat1 = torch.cat([img_v1, img_v2], dim=0) + img_cat2 = torch.cat([img_v2, img_v1], dim=0) + # compute query features + proj_online = self.online_net(img_cat1)[0] + with torch.no_grad(): + proj_target = self.target_net(img_cat2)[0].clone().detach() + + losses = self.head(proj_online, proj_target) + self._momentum_update() + return losses + + def forward_test(self, img, **kwargs): + pass + + def forward(self, img, mode='train', **kwargs): + if mode == 'train': + return self.forward_train(img, **kwargs) + elif mode == 'test': + return self.forward_test(img, **kwargs) + elif mode == 'extract': + return self.backbone(img) + else: + raise Exception("No such mode: {}".format(mode)) diff --git a/openselfsup/models/heads/__init__.py b/openselfsup/models/heads/__init__.py index c6bd865c..426b13b6 100644 --- a/openselfsup/models/heads/__init__.py +++ b/openselfsup/models/heads/__init__.py @@ -1,3 +1,4 @@ from .contrastive_head import ContrastiveHead from .cls_head import ClsHead +from .latent_pred_head import LatentPredictHead from .multi_cls_head import MultiClsHead diff --git a/openselfsup/models/heads/latent_pred_head.py b/openselfsup/models/heads/latent_pred_head.py new file mode 100644 index 00000000..50b14dfb --- /dev/null +++ b/openselfsup/models/heads/latent_pred_head.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from mmcv.cnn import normal_init + +from ..registry import HEADS +from .. import builder + +@HEADS.register_module +class LatentPredictHead(nn.Module): + '''Head for contrastive learning. + ''' + + def __init__(self, predictor): + super(LatentPredictHead, self).__init__() + self.predictor = builder.build_neck(predictor) + + def init_weights(self, init_linear='normal'): + self.predictor.init_weights(init_linear=init_linear) + + def forward(self, input, target): + ''' + Args: + input (Tensor): NxC input features. + target (Tensor): NxC target features. + ''' + N = input.size(0) + pred = self.predictor([input])[0] + pred_norm = nn.functional.normalize(pred, dim=1) + target_norm = nn.functional.normalize(target, dim=1) + loss = 2 - 2 * (pred_norm * target_norm).sum() / N + return dict(loss=loss) + + +@HEADS.register_module +class LatentClsHead(nn.Module): + '''Head for contrastive learning. + ''' + + def __init__(self, predictor): + super(LatentClsHead, self).__init__() + self.predictor = nn.Linear(predictor.in_channels, + predictor.num_classes) + self.criterion = nn.CrossEntropyLoss() + + def init_weights(self, init_linear='normal'): + normal_init(self.predictor, std=0.01) + + def forward(self, input, target): + ''' + Args: + input (Tensor): NxC input features. + target (Tensor): NxC target features. + ''' + pred = self.predictor(input) + with torch.no_grad(): + label = torch.argmax(self.predictor(target), dim=1).detach() + loss = self.criterion(pred, label) + return dict(loss=loss) diff --git a/openselfsup/models/necks.py b/openselfsup/models/necks.py index f4cd140e..9fa3bb0a 100644 --- a/openselfsup/models/necks.py +++ b/openselfsup/models/necks.py @@ -7,8 +7,27 @@ from .registry import NECKS from .utils import build_norm_layer +def _init_weights(module, init_linear='normal'): + assert init_linear in ['normal', 'kaiming'], \ + "Undefined init_linear: {}".format(init_linear) + for m in module.modules(): + if isinstance(m, nn.Linear): + if init_linear == 'normal': + normal_init(m, std=0.01) + else: + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, + (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @NECKS.register_module class LinearNeck(nn.Module): + '''Linear neck: fc only + ''' def __init__(self, in_channels, out_channels, with_avg_pool=True): super(LinearNeck, self).__init__() @@ -18,31 +37,19 @@ class LinearNeck(nn.Module): self.fc = nn.Linear(in_channels, out_channels) def init_weights(self, init_linear='normal'): - assert init_linear in ['normal', 'kaiming'], \ - "Undefined init_linear: {}".format(init_linear) - for m in self.modules(): - if isinstance(m, nn.Linear): - if init_linear == 'normal': - normal_init(m, std=0.01) - else: - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, - (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) + _init_weights(self, init_linear) def forward(self, x): assert len(x) == 1 + x = x[0] if self.with_avg_pool: - x = self.avgpool(x[0]) + x = self.avgpool(x) return [self.fc(x.view(x.size(0), -1))] @NECKS.register_module class NonLinearNeckV0(nn.Module): - '''The non-linear neck in ODC + '''The non-linear neck in ODC, fc-bn-relu-dropout-fc-relu ''' def __init__(self, @@ -61,25 +68,13 @@ class NonLinearNeckV0(nn.Module): nn.Linear(hid_channels, out_channels), nn.ReLU(inplace=True)) def init_weights(self, init_linear='normal'): - assert init_linear in ['normal', 'kaiming'], \ - "Undefined init_linear: {}".format(init_linear) - for m in self.modules(): - if isinstance(m, nn.Linear): - if init_linear == 'normal': - normal_init(m, std=0.01) - else: - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, - (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) + _init_weights(self, init_linear) def forward(self, x): assert len(x) == 1 + x = x[0] if self.with_avg_pool: - x = self.avgpool(x[0]) + x = self.avgpool(x) return [self.mlp(x.view(x.size(0), -1))] @@ -101,25 +96,43 @@ class NonLinearNeckV1(nn.Module): nn.Linear(hid_channels, out_channels)) def init_weights(self, init_linear='normal'): - assert init_linear in ['normal', 'kaiming'], \ - "Undefined init_linear: {}".format(init_linear) - for m in self.modules(): - if isinstance(m, nn.Linear): - if init_linear == 'normal': - normal_init(m, std=0.01) - else: - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, - (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) + _init_weights(self, init_linear) def forward(self, x): assert len(x) == 1 + x = x[0] if self.with_avg_pool: - x = self.avgpool(x[0]) + x = self.avgpool(x) + return [self.mlp(x.view(x.size(0), -1))] + + +@NECKS.register_module +class NonLinearNeckV2(nn.Module): + '''The non-linear neck in byol: fc-bn-relu-fc + ''' + def __init__(self, + in_channels, + hid_channels, + out_channels, + with_avg_pool=True): + super(NonLinearNeckV2, self).__init__() + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), + nn.BatchNorm1d(hid_channels), + nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + def init_weights(self, init_linear='normal'): + _init_weights(self, init_linear) + + def forward(self, x): + assert len(x) == 1, "Got: {}".format(len(x)) + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) return [self.mlp(x.view(x.size(0), -1))] @@ -178,20 +191,7 @@ class NonLinearNeckSimCLR(nn.Module): self.bn_names.append("bn{}".format(i)) def init_weights(self, init_linear='normal'): - assert init_linear in ['normal', 'kaiming'], \ - "Undefined init_linear: {}".format(init_linear) - for m in self.modules(): - if isinstance(m, nn.Linear): - if init_linear == 'normal': - normal_init(m, std=0.01) - else: - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, - (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) + _init_weights(self, init_linear) def _forward_syncbn(self, module, x): assert x.dim() == 2 @@ -203,8 +203,9 @@ class NonLinearNeckSimCLR(nn.Module): def forward(self, x): assert len(x) == 1 + x = x[0] if self.with_avg_pool: - x = self.avgpool(x[0]) + x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc0(x) x = self._forward_syncbn(self.bn0, x) diff --git a/openselfsup/utils/optimizers.py b/openselfsup/utils/optimizers.py index 8e756ea5..d6ee8467 100644 --- a/openselfsup/utils/optimizers.py +++ b/openselfsup/utils/optimizers.py @@ -1,7 +1,7 @@ -""" Layer-wise adaptive rate scaling for SGD in PyTorch! """ import torch from torch.optim.optimizer import Optimizer, required from torch.optim import * +from .larc import LARC class LARS(Optimizer): @@ -14,15 +14,17 @@ class LARS(Optimizer): momentum (float, optional): momentum factor (default: 0) ("m") weight_decay (float, optional): weight decay (L2 penalty) (default: 0) ("\beta") + dampening (float, optional): dampening for momentum (default: 0) eta (float, optional): LARS coefficient - max_epoch: maximum training epoch to determine polynomial LR decay. + nesterov (bool, optional): enables Nesterov momentum (default: False) Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. Large Batch Training of Convolutional Networks: https://arxiv.org/abs/1708.03888 Example: - >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, + >>> weight_decay=1e-4, eta=1e-3) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() @@ -31,9 +33,11 @@ class LARS(Optimizer): def __init__(self, params, lr=required, - momentum=.9, - weight_decay=.0005, - eta=0.001): + momentum=0, + dampening=0, + weight_decay=0, + eta=0.001, + nesterov=False): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -45,51 +49,69 @@ class LARS(Optimizer): raise ValueError("Invalid LARS coefficient value: {}".format(eta)) defaults = dict( - lr=lr, momentum=momentum, weight_decay=weight_decay, eta=eta) + lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, eta=eta) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(LARS, self).__init__(params, defaults) + def __setstate__(self, state): + super(LARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. - epoch: current epoch to calculate polynomial LR decay schedule. - if None, uses self.epoch and increments it. """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum'] + dampening = group['dampening'] eta = group['eta'] + nesterov = group['nesterov'] lr = group['lr'] + lars_exclude = group.get('lars_exclude', False) for p in group['params']: if p.grad is None: continue - param_state = self.state[p] - d_p = p.grad.data + d_p = p.grad - weight_norm = torch.norm(p.data) - grad_norm = torch.norm(d_p) - - # Compute local learning rate for this layer - local_lr = eta * weight_norm / \ - (grad_norm + weight_decay * weight_norm) - - # Update the momentum term - actual_lr = local_lr * lr - - if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = \ - torch.zeros_like(p.data) + if lars_exclude: + local_lr = 1. else: - buf = param_state['momentum_buffer'] - buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data) - p.data.add_(-buf) + weight_norm = torch.norm(p).item() + grad_norm = torch.norm(d_p).item() + # Compute local learning rate for this layer + local_lr = eta * weight_norm / \ + (grad_norm + weight_decay * weight_norm) + + actual_lr = local_lr * lr + d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = \ + torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + p.add_(-d_p) return loss diff --git a/openselfsup/version.py b/openselfsup/version.py index 7b4c41b6..b0338668 100644 --- a/openselfsup/version.py +++ b/openselfsup/version.py @@ -1,5 +1,5 @@ # GENERATED VERSION FILE -# TIME: Wed Jun 17 21:13:55 2020 +# TIME: Mon Jun 29 00:10:22 2020 -__version__ = '0.1.0+696d049' -short_version = '0.1.0' +__version__ = '0.2.0+6891da7' +short_version = '0.2.0' diff --git a/setup.py b/setup.py index ea3d3a53..2e2a392c 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ def readme(): MAJOR = 0 -MINOR = 1 +MINOR = 2 PATCH = 0 SUFFIX = '' if PATCH != '': diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100644 index 00000000..38f33896 --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -x + +CFG=$1 +GPUS=$2 +CHECKPOINT=$3 +PORT=${PORT:-29500} + +WORK_DIR="$(dirname $CHECKPOINT)/" + +# test +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + tools/test.py \ + $CFG \ + $CHECKPOINT \ + --work_dir $WORK_DIR --launcher="pytorch" diff --git a/tools/extract.py b/tools/extract.py index d5d0b2d0..c4075416 100644 --- a/tools/extract.py +++ b/tools/extract.py @@ -74,6 +74,8 @@ def parse_args(): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--port', type=int, default=29500, + help='port only works when launcher=="slurm"') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -107,6 +109,8 @@ def main(): distributed = False else: distributed = True + if args.launcher == 'slurm': + cfg.dist_params['port'] = args.port init_dist(args.launcher, **cfg.dist_params) # create work_dir diff --git a/tools/prepare_data/convert_subset.py b/tools/prepare_data/convert_subset.py new file mode 100644 index 00000000..3513705e --- /dev/null +++ b/tools/prepare_data/convert_subset.py @@ -0,0 +1,35 @@ +''' +SimCLR provides list files for semi-supervised benchmarks: +https://github.com/google-research/simclr/tree/master/imagenet_subsets/ +This script convert the list files into the required format in OpenSelfSup. +''' +import argparse + +parser = argparse.ArgumentParser( + description='Convert ImageNet subset lists provided by simclr.') +parser.add_argument('input', help='Input list file.') +parser.add_argument('output', help='Output list file.') +args = parser.parse_args() + +# create dict +with open("data/imagenet/meta/train_labeled.txt", 'r') as f: + lines = f.readlines() +keys = [l.split('/')[0] for l in lines] +labels = [l.strip().split()[1] for l in lines] +mapping = {} +for k,l in zip(keys, labels): + if k not in mapping: + mapping[k] = l + else: + assert mapping[k] == l + +# convert +with open(args.input, 'r') as f: + lines = f.readlines() +fns = [l.strip() for l in lines] +sample_keys = [l.split('_')[0] for l in lines] +sample_labels = [mapping[k] for k in sample_keys] +output_lines = ["{}/{} {}\n".format(k, fn, l) for \ + k,fn,l in zip(sample_keys, fns, sample_labels)] +with open(args.output, 'w') as f: + f.writelines(output_lines) diff --git a/tools/publish_model.py b/tools/publish_model.py index 4dd35332..9da1cf83 100644 --- a/tools/publish_model.py +++ b/tools/publish_model.py @@ -1,8 +1,6 @@ import argparse import subprocess -import torch - def parse_args(): parser = argparse.ArgumentParser( @@ -12,22 +10,23 @@ def parse_args(): return args -def process_checkpoint(in_file, out_file): - checkpoint = torch.load(in_file, map_location='cpu') - # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] - # if it is necessary to remove some sensitive data in checkpoint['meta'], - # add the code here. - torch.save(checkpoint, in_file + ".tmp.pth") - sha = subprocess.check_output(['sha256sum', out_file]).decode() - final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth' - subprocess.Popen(['mv', in_file + ".tmp.pth", final_file]) +def process_checkpoint(in_file): + tmp_file = in_file + ".tmp" + subprocess.Popen(['cp', in_file, tmp_file]) + sha = subprocess.check_output(['sha256sum', tmp_file]).decode() + out_file = in_file + if out_file.endswith('.pth'): + out_file = out_file[:-4] + final_file = out_file + f'-{sha[:8]}.pth' + assert final_file != in_file, \ + "The output filename is the same as the input file." + print("Output file: {}".format(final_file)) + subprocess.Popen(['mv', tmp_file, final_file]) def main(): args = parse_args() - process_checkpoint(args.in_file, args.in_file) + process_checkpoint(args.in_file) if __name__ == '__main__': diff --git a/tools/srun_test.sh b/tools/srun_test.sh new file mode 100644 index 00000000..41b1772c --- /dev/null +++ b/tools/srun_test.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -x + +PARTITION=$1 +CFG=$2 +GPUS=$3 +CHECKPOINT=$4 +PY_ARGS=${@:5} # --port +JOB_NAME="openselfsup" +GPUS_PER_NODE=${GPUS_PER_NODE:-1} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} + +WORK_DIR="$(dirname $CHECKPOINT)/" + +# test +GLOG_vmodule=MemcachedClient=-1 \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/test.py \ + $CFG \ + $CHECKPOINT \ + --work_dir $WORK_DIR --launcher="slurm" $PY_ARGS