simclr non-linear neck with syncbn

pull/5/head
xiaohangzhan 2020-06-18 00:37:23 +08:00
parent 00c25405c4
commit 8455d606c8
8 changed files with 198 additions and 6 deletions

View File

@ -9,6 +9,11 @@ PY_ARGS=${@:3}
GPUS=1 # in the standard setting, GPUS=1
PORT=${PORT:-29500}
if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
echo "ERROR: Missing arguments."
exit
fi
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
# train
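
The WORK_DIR line maps the config path into work_dirs/ and appends the basename of the pretrained checkpoint. A minimal Python sketch of the same derivation (paths are hypothetical):

cfg = "configs/selfsup/simclr/r50_ep200.py"      # $CFG (hypothetical)
pretrain = "checkpoints/simclr_r50.pth"          # $PRETRAIN (hypothetical)

stem = cfg.rsplit(".", 1)[0]                     # ${CFG%.*}: drop the extension
stem = stem.replace("configs", "work_dirs")      # sed -e "s/configs/work_dirs/g"
basename = pretrain.rsplit("/", 1)[-1]           # rev | cut -d/ -f 1 | rev
print(stem + "/" + basename)  # work_dirs/selfsup/simclr/r50_ep200/simclr_r50.pth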

View File

@ -8,6 +8,11 @@ PRETRAIN=$2
PY_ARGS=${@:3}
GPUS=8 # in the standard setting, GPUS=8
if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
echo "ERROR: Missing arguments."
exit
fi
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
# train

View File

@ -8,6 +8,11 @@ FEAT_LIST=$3 # e.g. "feat5", "feat4 feat5". If left empty, the default is "feat5".
GPUS=${4:-8}
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
if [ "$CFG" == "" ] || [ "$EPOCH" == "" ]; then
echo "ERROR: Missing arguments."
exit
fi
if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
echo "ERROR: File not exist: $WORK_DIR/epoch_${EPOCH}.pth"
exit

View File

@ -8,7 +8,12 @@ FEAT_LIST=$3 # e.g. "feat5", "feat4 feat5". If left empty, the default is "feat5".
GPUS=${4:-8}
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
    echo "ERROR: Missing arguments."
    exit 1
fi
if [ ! -f "$PRETRAIN" ] && [ "$PRETRAIN" != "random" ]; then
    echo "ERROR: PRETRAIN should be a file or the string \"random\", got: $PRETRAIN"
    exit 1
fi

View File

@ -10,7 +10,7 @@ model = dict(
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='SyncBN')),
neck=dict(
        type='NonLinearNeckV1',  # simple fc-relu-fc neck
in_channels=2048,
hid_channels=2048,
out_channels=128,
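
The "simple fc-relu-fc neck" comment summarizes what NonLinearNeckV1 builds with these settings. A minimal sketch of the equivalent projection (illustrative only; the registered module also handles average pooling and weight initialization):

import torch.nn as nn

# fc-relu-fc with the channel sizes from the config above
mlp = nn.Sequential(
    nn.Linear(2048, 2048),   # in_channels -> hid_channels
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128))    # hid_channels -> out_channels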

View File

@ -0,0 +1,78 @@
_base_ = '../../base.py'
# model settings
model = dict(
type='SimCLR',
pretrained=None,
backbone=dict(
type='ResNet',
depth=50,
in_channels=3,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='SyncBN')),
neck=dict(
type='NonLinearNeckSimCLR', # SimCLR non-linear neck
in_channels=2048,
hid_channels=2048,
out_channels=128,
num_layers=2,
with_avg_pool=True),
head=dict(type='ContrastiveHead', temperature=0.1))
# dataset settings
data_source_cfg = dict(
type='ImageNet',
memcached=True,
mclient_path='/mnt/lustre/share/memcached_client')
data_train_list = 'data/imagenet/meta/train.txt'
data_train_root = 'data/imagenet/train'
dataset_type = 'ContrastiveDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=224),
dict(type='RandomHorizontalFlip'),
dict(
type='RandomAppliedTrans',
transforms=[
dict(
type='ColorJitter',
brightness=0.8,
contrast=0.8,
saturation=0.8,
hue=0.2)
],
p=0.8),
dict(type='RandomGrayscale', p=0.2),
dict(
type='RandomAppliedTrans',
transforms=[
dict(
type='GaussianBlur',
sigma_min=0.1,
sigma_max=2.0,
kernel_size=23)
],
p=0.5),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]
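
The ContrastiveDataset named above runs this pipeline twice per image to produce the two views SimCLR contrasts. A conceptual sketch of that behavior (not the repo's implementation):

class TwoViewDataset(object):
    """Sample the stochastic augmentation pipeline twice per image."""

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __getitem__(self, idx):
        img = self.images[idx]
        # two independent draws of the random augmentations
        return self.transform(img), self.transform(img)

    def __len__(self):
        return len(self.images)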
data = dict(
    imgs_per_gpu=32,  # total 32*8 = 256
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline))
# optimizer
optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9)
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
min_lr=0.,
warmup='linear',
warmup_iters=10,
warmup_ratio=0.01,
warmup_by_epoch=True)
checkpoint_config = dict(interval=10)
# runtime settings
total_epochs = 200
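
With 32 images per GPU on 8 GPUs, the total batch size is 256, so lr=0.3 matches the SimCLR linear-scaling rule (0.3 * BatchSize / 256) for LARS. The schedule is 10 epochs of linear warmup from 0.01 * lr, followed by cosine annealing to min_lr. A rough per-epoch sketch of the resulting learning rate (not the mmcv implementation, which updates per iteration):

import math

def lr_at_epoch(epoch, base_lr=0.3, warmup_epochs=10, warmup_ratio=0.01,
                min_lr=0.0, total_epochs=200):
    if epoch < warmup_epochs:
        # linear warmup: warmup_ratio * base_lr -> base_lr
        frac = epoch / float(warmup_epochs)
        return base_lr * (warmup_ratio + (1.0 - warmup_ratio) * frac)
    # cosine annealing over the remaining epochs
    t = (epoch - warmup_epochs) / float(total_epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * t))

print(lr_at_epoch(0), lr_at_epoch(10), lr_at_epoch(200))  # ~0.003, 0.3, 0.0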

View File

@ -1,7 +1,10 @@
import torch
import torch.nn as nn
from distutils.version import StrictVersion
from mmcv.cnn import kaiming_init, normal_init
from .registry import NECKS
from .utils import build_norm_layer
@NECKS.register_module
@ -80,7 +83,8 @@ class NonLinearNeckV0(nn.Module):
@NECKS.register_module
class NonLinearNeckV1(nn.Module):
'''Simple non-linear neck: fc-relu-fc
'''
def __init__(self,
in_channels,
hid_channels,
@ -117,6 +121,97 @@ class NonLinearNeckV1(nn.Module):
return [self.mlp(x.view(x.size(0), -1))]
@NECKS.register_module
class NonLinearNeckSimCLR(nn.Module):
    '''SimCLR non-linear neck.

    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
    The substructures in [] can be repeated. For the default SimCLR setting,
    the repeat count is 1.

    Note that PyTorch does not support specifying (weight=True, bias=False)
    for BatchNorm: a single "affine" flag controls both the weight and the
    bias. Hence, the second BatchNorm has a bias in this implementation,
    which differs from the official SimCLR implementation.

    Since SyncBatchNorm in pytorch<1.4.0 does not support 2D input, the
    input is expanded to 4D with shape (N, C, 1, 1). I am not sure whether
    this workaround is bug-free. See the pull request here:
    https://github.com/pytorch/pytorch/pull/29626
    '''
def __init__(self,
in_channels,
hid_channels,
out_channels,
num_layers=2,
with_avg_pool=True):
super(NonLinearNeckSimCLR, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if StrictVersion(torch.__version__) < StrictVersion("1.4.0"):
self.expand_for_syncbn = True
else:
self.expand_for_syncbn = False
self.relu = nn.ReLU(inplace=True)
self.fc0 = nn.Linear(in_channels, hid_channels, bias=False)
_, self.bn0 = build_norm_layer(
dict(type='SyncBN'), hid_channels)
self.fc_names = []
self.bn_names = []
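        # build the repeated [relu-fc-bn] blocks; the final fc maps to out_channels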
for i in range(1, num_layers):
this_channels = out_channels if i == num_layers - 1 \
else hid_channels
self.add_module(
"fc{}".format(i),
nn.Linear(hid_channels, this_channels, bias=False))
self.add_module(
"bn{}".format(i),
build_norm_layer(dict(type='SyncBN'), this_channels)[1])
self.fc_names.append("fc{}".format(i))
self.bn_names.append("bn{}".format(i))
def init_weights(self, init_linear='normal'):
assert init_linear in ['normal', 'kaiming'], \
"Undefined init_linear: {}".format(init_linear)
for m in self.modules():
if isinstance(m, nn.Linear):
if init_linear == 'normal':
normal_init(m, std=0.01)
else:
kaiming_init(m, mode='fan_in', nonlinearity='relu')
elif isinstance(m,
(nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _forward_syncbn(self, module, x):
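        # SyncBN in torch < 1.4.0 only supports 4D input, so the (N, C)
        # features are expanded to (N, C, 1, 1) around the BN call and
        # squeezed back afterwards (see the class docstring).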
assert x.dim() == 2
if self.expand_for_syncbn:
x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
else:
x = module(x)
return x
def forward(self, x):
assert len(x) == 1
if self.with_avg_pool:
x = self.avgpool(x[0])
x = x.view(x.size(0), -1)
x = self.fc0(x)
x = self._forward_syncbn(self.bn0, x)
for fc_name, bn_name in zip(self.fc_names, self.bn_names):
fc = getattr(self, fc_name)
bn = getattr(self, bn_name)
x = self.relu(x)
x = fc(x)
x = self._forward_syncbn(bn, x)
return [x]
@NECKS.register_module
class AvgPoolNeck(nn.Module):
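
To sanity-check the 2D-to-4D workaround in _forward_syncbn, here is a self-contained sketch using plain BatchNorm in place of SyncBN (which would need an initialized process group): running an (N, C) batch through BatchNorm2d on an (N, C, 1, 1) view is numerically identical to BatchNorm1d on the 2D input.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 128)  # (N, C) features, as seen by _forward_syncbn

bn1d = nn.BatchNorm1d(128)
bn2d = nn.BatchNorm2d(128)
bn2d.load_state_dict(bn1d.state_dict())  # identical affine params and stats

y1 = bn1d(x)
y2 = bn2d(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
print(torch.allclose(y1, y2, atol=1e-6))  # True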

View File

@ -6,7 +6,6 @@ from openselfsup.utils import print_log
from . import builder
from .registry import MODELS
from .utils import GatherLayer
@MODELS.register_module
@ -58,10 +57,10 @@ class SimCLR(nn.Module):
s = torch.matmul(z, z.permute(1, 0)) # (2N)x(2N)
mask, pos_ind, neg_mask = self._create_buffer(N)
# remove diagonal, (2N)x(2N-1)
        s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1)
positive = s[pos_ind].unsqueeze(1) # (2N)x1
# select negative, (2N)x(2N-2)
        negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1)
losses = self.head(positive, negative)
return losses
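
To make the masking concrete: each row of s loses its self-similarity (the diagonal), leaving 2N-1 entries, of which one is the positive (the other view of the same image) and 2N-2 are negatives. A toy recreation for N=2, assuming the two views of each image sit in adjacent rows of z (as the model's reshape produces); the masks are built inline here rather than by _create_buffer:

import torch
import torch.nn.functional as F

N = 2                                          # images in the batch; 2N views
z = F.normalize(torch.randn(2 * N, 8), dim=1)  # rows: [x0_v0, x0_v1, x1_v0, x1_v1]
s = z @ z.t()                                  # (2N, 2N) similarity matrix

off_diag = ~torch.eye(2 * N, dtype=torch.bool)
s = torch.masked_select(s, off_diag).reshape(2 * N, -1)  # (2N, 2N-1)

rows = torch.arange(2 * N)
pos_cols = rows - rows % 2  # with the diagonal gone, row i's paired view
                            # lands in column i - (i % 2)
positive = s[rows, pos_cols].unsqueeze(1)                # (2N, 1)

neg_mask = torch.ones(2 * N, 2 * N - 1, dtype=torch.bool)
neg_mask[rows, pos_cols] = False
negative = torch.masked_select(s, neg_mask).reshape(2 * N, -1)  # (2N, 2N-2)
print(positive.shape, negative.shape)  # torch.Size([4, 1]) torch.Size([4, 2])

From here, ContrastiveHead concatenates the positive and negative scores, scales them by 1/temperature (0.1 in the new config), and applies cross-entropy with the positive as the target class.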