diff --git a/benchmarks/dist_test_linear.sh b/benchmarks/dist_test_linear.sh
index ad3a9b1e..c1a5b373 100755
--- a/benchmarks/dist_test_linear.sh
+++ b/benchmarks/dist_test_linear.sh
@@ -9,6 +9,11 @@ PY_ARGS=${@:3}
 GPUS=1 # in the standard setting, GPUS=1
 PORT=${PORT:-29500}
 
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
 
 # train
diff --git a/benchmarks/dist_test_semi.sh b/benchmarks/dist_test_semi.sh
index 8ca3b736..9764a3d3 100644
--- a/benchmarks/dist_test_semi.sh
+++ b/benchmarks/dist_test_semi.sh
@@ -8,6 +8,11 @@ PRETRAIN=$2
 PY_ARGS=${@:3}
 GPUS=8 # in the standard setting, GPUS=8
 
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
 
 # train
diff --git a/benchmarks/dist_test_svm_epoch.sh b/benchmarks/dist_test_svm_epoch.sh
index 9daa32cf..31ac42c4 100644
--- a/benchmarks/dist_test_svm_epoch.sh
+++ b/benchmarks/dist_test_svm_epoch.sh
@@ -8,6 +8,11 @@ FEAT_LIST=$3 # e.g.: "feat5", "feat4 feat5". If leave empty, the default is "fea
 GPUS=${4:-8}
 WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
 
+if [ "$CFG" == "" ] || [ "$EPOCH" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
     echo "ERROR: File not exist: $WORK_DIR/epoch_${EPOCH}.pth"
     exit
diff --git a/benchmarks/dist_test_svm_pretrain.sh b/benchmarks/dist_test_svm_pretrain.sh
index b12bade8..911fe07d 100644
--- a/benchmarks/dist_test_svm_pretrain.sh
+++ b/benchmarks/dist_test_svm_pretrain.sh
@@ -8,7 +8,12 @@ FEAT_LIST=$3 # e.g.: "feat5", "feat4 feat5". If leave empty, the default is "fea
 GPUS=${4:-8}
 WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
 
-if [ ! -f $PRETRAIN ] and [ "$PRETRAIN" != "random" ]; then
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
+if [ ! -f $PRETRAIN ] && [ "$PRETRAIN" != "random" ]; then
     echo "ERROR: PRETRAIN should be a file or a string \"random\", got: $PRETRAIN"
     exit
 fi
diff --git a/configs/selfsup/simclr/r50_bs256.py b/configs/selfsup/simclr/r50_bs256.py
index cb087730..989e9abb 100644
--- a/configs/selfsup/simclr/r50_bs256.py
+++ b/configs/selfsup/simclr/r50_bs256.py
@@ -10,7 +10,7 @@ model = dict(
         out_indices=[4],  # 0: conv-1, x: stage-x
         norm_cfg=dict(type='SyncBN')),
     neck=dict(
-        type='NonLinearNeckV1',
+        type='NonLinearNeckV1',  # simple fc-relu-fc neck
         in_channels=2048,
         hid_channels=2048,
         out_channels=128,
diff --git a/configs/selfsup/simclr/r50_bs256_simclr_neck.py b/configs/selfsup/simclr/r50_bs256_simclr_neck.py
new file mode 100644
index 00000000..9766bdf2
--- /dev/null
+++ b/configs/selfsup/simclr/r50_bs256_simclr_neck.py
@@ -0,0 +1,78 @@
+_base_ = '../../base.py'
+# model settings
+model = dict(
+    type='SimCLR',
+    pretrained=None,
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        in_channels=3,
+        out_indices=[4],  # 0: conv-1, x: stage-x
+        norm_cfg=dict(type='SyncBN')),
+    neck=dict(
+        type='NonLinearNeckSimCLR',  # SimCLR non-linear neck
+        in_channels=2048,
+        hid_channels=2048,
+        out_channels=128,
+        num_layers=2,
+        with_avg_pool=True),
+    head=dict(type='ContrastiveHead', temperature=0.1))
+# dataset settings
+data_source_cfg = dict(
+    type='ImageNet',
+    memcached=True,
+    mclient_path='/mnt/lustre/share/memcached_client')
+data_train_list = 'data/imagenet/meta/train.txt'
+data_train_root = 'data/imagenet/train'
+dataset_type = 'ContrastiveDataset'
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+    dict(type='RandomResizedCrop', size=224),
+    dict(type='RandomHorizontalFlip'),
+    dict(
+        type='RandomAppliedTrans',
+        transforms=[
+            dict(
+                type='ColorJitter',
+                brightness=0.8,
+                contrast=0.8,
+                saturation=0.8,
+                hue=0.2)
+        ],
+        p=0.8),
+    dict(type='RandomGrayscale', p=0.2),
+    dict(
+        type='RandomAppliedTrans',
+        transforms=[
+            dict(
+                type='GaussianBlur',
+                sigma_min=0.1,
+                sigma_max=2.0,
+                kernel_size=23)
+        ],
+        p=0.5),
+    dict(type='ToTensor'),
+    dict(type='Normalize', **img_norm_cfg),
+]
+data = dict(
+    imgs_per_gpu=32,  # total 32*8
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_source=dict(
+            list_file=data_train_list, root=data_train_root,
+            **data_source_cfg),
+        pipeline=train_pipeline))
+# optimizer
+optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9)
+# learning policy
+lr_config = dict(
+    policy='CosineAnealing',
+    min_lr=0.,
+    warmup='linear',
+    warmup_iters=10,
+    warmup_ratio=0.01,
+    warmup_by_epoch=True)
+checkpoint_config = dict(interval=10)
+# runtime settings
+total_epochs = 200
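For orientation, the only change relative to r50_bs256.py is the neck type. A minimal stand-alone sketch of the two projection-head layouts, using plain torch.nn modules with the channel sizes from the config above; SyncBN, the registry, and weight init are ignored here, so this is an illustration rather than the repo's actual builders:

import torch
import torch.nn as nn

# fc-relu-fc head, roughly what NonLinearNeckV1 builds (2048 -> 2048 -> 128)
neck_v1 = nn.Sequential(
    nn.Linear(2048, 2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128))

# fc-bn-relu-fc-bn head, roughly what NonLinearNeckSimCLR builds with
# num_layers=2; BatchNorm1d stands in for SyncBN, and the fc layers carry no bias
neck_simclr = nn.Sequential(
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
    nn.BatchNorm1d(128))

feats = torch.randn(32, 2048)    # pooled backbone features, N=32
print(neck_v1(feats).shape)      # torch.Size([32, 128])
print(neck_simclr(feats).shape)  # torch.Size([32, 128])

Both heads map pooled 2048-d backbone features to 128-d embeddings; the SimCLR-style head differs only in the normalization layers and the bias-free fc layers.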
diff --git a/openselfsup/models/necks.py b/openselfsup/models/necks.py
index 7d509cad..2240e571 100644
--- a/openselfsup/models/necks.py
+++ b/openselfsup/models/necks.py
@@ -1,7 +1,10 @@
+import torch
 import torch.nn as nn
+from distutils.version import StrictVersion
 from mmcv.cnn import kaiming_init, normal_init
 
 from .registry import NECKS
+from .utils import build_norm_layer
 
 
 @NECKS.register_module
@@ -80,7 +83,8 @@ class NonLinearNeckV0(nn.Module):
 
 @NECKS.register_module
 class NonLinearNeckV1(nn.Module):
-
+    '''Simple non-linear neck: fc-relu-fc.
+    '''
     def __init__(self,
                  in_channels,
                  hid_channels,
@@ -117,6 +121,97 @@ class NonLinearNeckV1(nn.Module):
         return [self.mlp(x.view(x.size(0), -1))]
 
 
+@NECKS.register_module
+class NonLinearNeckSimCLR(nn.Module):
+    '''SimCLR non-linear head.
+
+    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
+    The substructures in [] can be repeated; for the default SimCLR setting,
+    they are repeated once.
+    However, PyTorch does not support specifying (weight=True, bias=False);
+    it only supports "affine", which covers both the weight and the bias.
+    Hence, the second BatchNorm has a bias in this implementation, which
+    differs from the official implementation of SimCLR.
+    Since SyncBatchNorm in pytorch<1.4.0 does not support 2D input, the input
+    is expanded to 4D with shape (N,C,1,1). I am not sure whether this
+    workaround is bug-free. See the pull request here:
+    https://github.com/pytorch/pytorch/pull/29626
+    '''
+
+    def __init__(self,
+                 in_channels,
+                 hid_channels,
+                 out_channels,
+                 num_layers=2,
+                 with_avg_pool=True):
+        super(NonLinearNeckSimCLR, self).__init__()
+        self.with_avg_pool = with_avg_pool
+        if with_avg_pool:
+            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        if StrictVersion(torch.__version__) < StrictVersion("1.4.0"):
+            self.expand_for_syncbn = True
+        else:
+            self.expand_for_syncbn = False
+
+        self.relu = nn.ReLU(inplace=True)
+        self.fc0 = nn.Linear(in_channels, hid_channels, bias=False)
+        _, self.bn0 = build_norm_layer(
+            dict(type='SyncBN'), hid_channels)
+
+        self.fc_names = []
+        self.bn_names = []
+        for i in range(1, num_layers):
+            this_channels = out_channels if i == num_layers - 1 \
+                else hid_channels
+            self.add_module(
+                "fc{}".format(i),
+                nn.Linear(hid_channels, this_channels, bias=False))
+            self.add_module(
+                "bn{}".format(i),
+                build_norm_layer(dict(type='SyncBN'), this_channels)[1])
+            self.fc_names.append("fc{}".format(i))
+            self.bn_names.append("bn{}".format(i))
+
+    def init_weights(self, init_linear='normal'):
+        assert init_linear in ['normal', 'kaiming'], \
+            "Undefined init_linear: {}".format(init_linear)
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                if init_linear == 'normal':
+                    normal_init(m, std=0.01)
+                else:
+                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
+            elif isinstance(m,
+                            (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _forward_syncbn(self, module, x):
+        assert x.dim() == 2
+        if self.expand_for_syncbn:
+            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
+        else:
+            x = module(x)
+        return x
+
+    def forward(self, x):
+        assert len(x) == 1
+        if self.with_avg_pool:
+            x = self.avgpool(x[0])
+        x = x.view(x.size(0), -1)
+        x = self.fc0(x)
+        x = self._forward_syncbn(self.bn0, x)
+        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
+            fc = getattr(self, fc_name)
+            bn = getattr(self, bn_name)
+            x = self.relu(x)
+            x = fc(x)
+            x = self._forward_syncbn(bn, x)
+        return [x]
+
+
 @NECKS.register_module
 class AvgPoolNeck(nn.Module):
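The expand_for_syncbn path above reshapes the (N, C) features to (N, C, 1, 1) before the norm layer and squeezes back afterwards. A minimal sanity check of that reshaping trick, using ordinary BatchNorm layers (which accept both 2D and 4D input) instead of SyncBN, since SyncBN needs a distributed process group:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 128)                  # (N, C) features, as inside the neck

bn1d = nn.BatchNorm1d(128)
bn2d = nn.BatchNorm2d(128)
bn2d.load_state_dict(bn1d.state_dict())  # identical affine params and stats

y_2d = bn1d(x)                           # normalize the 2D input directly
# the workaround path: expand to (N, C, 1, 1), normalize, squeeze back to (N, C)
y_4d = bn2d(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)

print(torch.allclose(y_2d, y_4d, atol=1e-6))  # True: both paths agree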
diff --git a/openselfsup/models/simclr.py b/openselfsup/models/simclr.py
index 16ece396..c27d62b4 100644
--- a/openselfsup/models/simclr.py
+++ b/openselfsup/models/simclr.py
@@ -6,7 +6,6 @@ from openselfsup.utils import print_log
 from . import builder
 from .registry import MODELS
 from .utils import GatherLayer
-import pdb
 
 
 @MODELS.register_module
@@ -58,10 +57,10 @@ class SimCLR(nn.Module):
         s = torch.matmul(z, z.permute(1, 0))  # (2N)x(2N)
         mask, pos_ind, neg_mask = self._create_buffer(N)
         # remove diagonal, (2N)x(2N-1)
-        s = torch.masked_select(s, mask).reshape(s.size(0), -1)
+        s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1)
         positive = s[pos_ind].unsqueeze(1)  # (2N)x1
         # select negative, (2N)x(2N-2)
-        negative = torch.masked_select(s, neg_mask).reshape(s.size(0), -1)
+        negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1)
         losses = self.head(positive, negative)
         return losses
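The "mask == 1" and "neg_mask == 1" comparisons turn the masks into boolean tensors, which recent PyTorch versions require for torch.masked_select (byte or floating-point masks are deprecated or rejected). A minimal illustration with a toy 4x4 similarity matrix, not the buffers actually returned by _create_buffer:

import torch

torch.manual_seed(0)
s = torch.randn(4, 4)    # toy similarity matrix, (2N)x(2N) with N=2
mask = 1 - torch.eye(4)  # float tensor: 1 off the diagonal, 0 on it

# masked_select needs a torch.bool mask on recent PyTorch; comparing the
# numeric mask against 1 yields exactly that, whatever dtype the mask had.
off_diag = torch.masked_select(s, mask == 1).reshape(4, -1)
print(off_diag.shape)    # torch.Size([4, 3]): the diagonal is removed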