simclr non-linear neck with syncbn

parent 00c25405c4
commit 8455d606c8

@@ -9,6 +9,11 @@ PY_ARGS=${@:3}
 GPUS=1 # in the standard setting, GPUS=1
 PORT=${PORT:-29500}
 
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
 
 # train

@@ -8,6 +8,11 @@ PRETRAIN=$2
 PY_ARGS=${@:3}
 GPUS=8 # in the standard setting, GPUS=8
 
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
 
 # train

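The WORK_DIR expression in the two training scripts above packs three steps into one line: strip the config's file extension, map the configs/ prefix to work_dirs/, and append the basename of the pretrained checkpoint. The snippet below is a rough Python rendering of the same transformation for illustration only; the CFG and PRETRAIN values are made-up example paths, not paths taken from this commit.

import os

# assumed example inputs, for illustration only
cfg = "configs/benchmarks/linear_classification/imagenet/r50_last.py"
pretrain = "work_dirs/selfsup/simclr/r50_bs256_ep200/epoch_200.pth"

# ${CFG%.*}                 -> drop the ".py" extension
# sed s/configs/work_dirs/g -> swap the configs/ prefix for work_dirs/
# rev | cut -d/ -f 1 | rev  -> keep only the checkpoint basename
work_dir = os.path.splitext(cfg)[0].replace("configs", "work_dirs") \
    + "/" + os.path.basename(pretrain)
print(work_dir)
# -> work_dirs/benchmarks/linear_classification/imagenet/r50_last/epoch_200.pth
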
@@ -8,6 +8,11 @@ FEAT_LIST=$3 # e.g.: "feat5", "feat4 feat5". If leave empty, the default is "fea
 GPUS=${4:-8}
 WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
 
+if [ "$CFG" == "" ] || [ "$EPOCH" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
 if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
     echo "ERROR: File not exist: $WORK_DIR/epoch_${EPOCH}.pth"
     exit

@@ -8,7 +8,12 @@ FEAT_LIST=$3 # e.g.: "feat5", "feat4 feat5". If leave empty, the default is "fea
 GPUS=${4:-8}
 WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
 
-if [ ! -f $PRETRAIN ] and [ "$PRETRAIN" != "random" ]; then
+if [ "$CFG" == "" ] || [ "$PRETRAIN" == "" ]; then
+    echo "ERROR: Missing arguments."
+    exit
+fi
+
+if [ ! -f $PRETRAIN ] && [ "$PRETRAIN" != "random" ]; then
     echo "ERROR: PRETRAIN should be a file or a string \"random\", got: $PRETRAIN"
     exit
 fi

@@ -10,7 +10,7 @@ model = dict(
         out_indices=[4], # 0: conv-1, x: stage-x
         norm_cfg=dict(type='SyncBN')),
     neck=dict(
-        type='NonLinearNeckV1',
+        type='NonLinearNeckV1', # simple fc-relu-fc neck
        in_channels=2048,
        hid_channels=2048,
        out_channels=128,

@@ -0,0 +1,78 @@
+_base_ = '../../base.py'
+# model settings
+model = dict(
+    type='SimCLR',
+    pretrained=None,
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        in_channels=3,
+        out_indices=[4], # 0: conv-1, x: stage-x
+        norm_cfg=dict(type='SyncBN')),
+    neck=dict(
+        type='NonLinearNeckSimCLR', # SimCLR non-linear neck
+        in_channels=2048,
+        hid_channels=2048,
+        out_channels=128,
+        num_layers=2,
+        with_avg_pool=True),
+    head=dict(type='ContrastiveHead', temperature=0.1))
+# dataset settings
+data_source_cfg = dict(
+    type='ImageNet',
+    memcached=True,
+    mclient_path='/mnt/lustre/share/memcached_client')
+data_train_list = 'data/imagenet/meta/train.txt'
+data_train_root = 'data/imagenet/train'
+dataset_type = 'ContrastiveDataset'
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+    dict(type='RandomResizedCrop', size=224),
+    dict(type='RandomHorizontalFlip'),
+    dict(
+        type='RandomAppliedTrans',
+        transforms=[
+            dict(
+                type='ColorJitter',
+                brightness=0.8,
+                contrast=0.8,
+                saturation=0.8,
+                hue=0.2)
+        ],
+        p=0.8),
+    dict(type='RandomGrayscale', p=0.2),
+    dict(
+        type='RandomAppliedTrans',
+        transforms=[
+            dict(
+                type='GaussianBlur',
+                sigma_min=0.1,
+                sigma_max=2.0,
+                kernel_size=23)
+        ],
+        p=0.5),
+    dict(type='ToTensor'),
+    dict(type='Normalize', **img_norm_cfg),
+]
+data = dict(
+    imgs_per_gpu=32, # total 32*8
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_source=dict(
+            list_file=data_train_list, root=data_train_root,
+            **data_source_cfg),
+        pipeline=train_pipeline))
+# optimizer
+optimizer = dict(type='LARS', lr=0.3, weight_decay=0.000001, momentum=0.9)
+# learning policy
+lr_config = dict(
+    policy='CosineAnealing',
+    min_lr=0.,
+    warmup='linear',
+    warmup_iters=10,
+    warmup_ratio=0.01,
+    warmup_by_epoch=True)
+checkpoint_config = dict(interval=10)
+# runtime settings
+total_epochs = 200

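The lr_config in this new config describes a linear warmup over the first 10 epochs, starting at warmup_ratio * lr = 0.003, followed by cosine annealing towards min_lr = 0 across the 200 training epochs. The snippet below is a standalone sketch of that schedule under the usual linear-warmup and cosine-annealing formulas; it is an illustration, not the actual LR hook used by the repo.

import math

base_lr, min_lr = 0.3, 0.0
warmup_epochs, warmup_ratio, total_epochs = 10, 0.01, 200

def lr_at(epoch):
    # linear warmup: ramp from warmup_ratio * base_lr up to base_lr
    if epoch < warmup_epochs:
        k = epoch / warmup_epochs
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * k)
    # cosine annealing over the remaining epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

for e in (0, 5, 10, 100, 199):
    print(e, round(lr_at(e), 4))  # e.g. epoch 0 -> 0.003, epoch 10 -> 0.3
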
@@ -1,7 +1,10 @@
+import torch
 import torch.nn as nn
+from distutils.version import StrictVersion
 from mmcv.cnn import kaiming_init, normal_init
 
 from .registry import NECKS
+from .utils import build_norm_layer
 
 
 @NECKS.register_module

@@ -80,7 +83,8 @@ class NonLinearNeckV0(nn.Module):
 
 @NECKS.register_module
 class NonLinearNeckV1(nn.Module):
-
+    '''Simple non-linear neck: fc-relu-fc
+    '''
     def __init__(self,
                  in_channels,
                  hid_channels,

@@ -117,6 +121,97 @@ class NonLinearNeckV1(nn.Module):
         return [self.mlp(x.view(x.size(0), -1))]
 
 
+@NECKS.register_module
+class NonLinearNeckSimCLR(nn.Module):
+    '''SimCLR non-linear head.
+    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
+    The substructures in [] can be repeated. For the default SimCLR setting,
+    the repeat time is 1.
+    However, PyTorch does not support specifying (weight=True, bias=False).
+    It only supports "affine", which includes both the weight and the bias.
+    Hence, the second BatchNorm has a bias in this implementation, which
+    differs from the official implementation of SimCLR.
+    Since SyncBatchNorm in pytorch<1.4.0 does not support 2D input, the input
+    is expanded to 4D with shape (N, C, 1, 1). I am not sure whether this
+    workaround is bug-free. See the pull request here:
+    https://github.com/pytorch/pytorch/pull/29626
+    '''
+
+    def __init__(self,
+                 in_channels,
+                 hid_channels,
+                 out_channels,
+                 num_layers=2,
+                 with_avg_pool=True):
+        super(NonLinearNeckSimCLR, self).__init__()
+        self.with_avg_pool = with_avg_pool
+        if with_avg_pool:
+            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        if StrictVersion(torch.__version__) < StrictVersion("1.4.0"):
+            self.expand_for_syncbn = True
+        else:
+            self.expand_for_syncbn = False
+
+        self.relu = nn.ReLU(inplace=True)
+        self.fc0 = nn.Linear(in_channels, hid_channels, bias=False)
+        _, self.bn0 = build_norm_layer(
+            dict(type='SyncBN'), hid_channels)
+
+        self.fc_names = []
+        self.bn_names = []
+        for i in range(1, num_layers):
+            this_channels = out_channels if i == num_layers - 1 \
+                else hid_channels
+            self.add_module(
+                "fc{}".format(i),
+                nn.Linear(hid_channels, this_channels, bias=False))
+            self.add_module(
+                "bn{}".format(i),
+                build_norm_layer(dict(type='SyncBN'), this_channels)[1])
+            self.fc_names.append("fc{}".format(i))
+            self.bn_names.append("bn{}".format(i))
+
+    def init_weights(self, init_linear='normal'):
+        assert init_linear in ['normal', 'kaiming'], \
+            "Undefined init_linear: {}".format(init_linear)
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                if init_linear == 'normal':
+                    normal_init(m, std=0.01)
+                else:
+                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
+            elif isinstance(m,
+                            (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _forward_syncbn(self, module, x):
+        assert x.dim() == 2
+        if self.expand_for_syncbn:
+            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
+        else:
+            x = module(x)
+        return x
+
+    def forward(self, x):
+        assert len(x) == 1
+        if self.with_avg_pool:
+            x = self.avgpool(x[0])
+        x = x.view(x.size(0), -1)
+        x = self.fc0(x)
+        x = self._forward_syncbn(self.bn0, x)
+        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
+            fc = getattr(self, fc_name)
+            bn = getattr(self, bn_name)
+            x = self.relu(x)
+            x = fc(x)
+            x = self._forward_syncbn(bn, x)
+        return [x]
+
+
 @NECKS.register_module
 class AvgPoolNeck(nn.Module):

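The docstring of NonLinearNeckSimCLR notes that SyncBatchNorm in pytorch<1.4.0 cannot take 2D input, so _forward_syncbn expands (N, C) features to (N, C, 1, 1) before the norm layer. The following is my own minimal sketch of that expansion trick, written with plain BatchNorm layers so it runs without a distributed process group; it is not part of the commit.

import torch
import torch.nn as nn

x = torch.randn(8, 128)                  # (N, C) features, as produced by an fc layer
bn1d = nn.BatchNorm1d(128)
bn2d = nn.BatchNorm2d(128)
bn2d.load_state_dict(bn1d.state_dict())  # same affine parameters and running stats

y_ref = bn1d(x)                          # normalize the 2D input directly
y_expand = bn2d(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
print(torch.allclose(y_ref, y_expand, atol=1e-6))  # True: the 2D -> 4D expansion is equivalent
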
@@ -6,7 +6,6 @@ from openselfsup.utils import print_log
 from . import builder
 from .registry import MODELS
 from .utils import GatherLayer
-import pdb
 
 
 @MODELS.register_module

@@ -58,10 +57,10 @@ class SimCLR(nn.Module):
         s = torch.matmul(z, z.permute(1, 0)) # (2N)x(2N)
         mask, pos_ind, neg_mask = self._create_buffer(N)
         # remove diagonal, (2N)x(2N-1)
-        s = torch.masked_select(s, mask).reshape(s.size(0), -1)
+        s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1)
         positive = s[pos_ind].unsqueeze(1) # (2N)x1
         # select negative, (2N)x(2N-2)
-        negative = torch.masked_select(s, neg_mask).reshape(s.size(0), -1)
+        negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1)
         losses = self.head(positive, negative)
         return losses
 
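The masked_select calls above flatten the (2N)x(2N) similarity matrix after dropping entries selected by the buffers from _create_buffer; comparing with "== 1" turns the stored buffers into the boolean masks that masked_select expects. Below is a toy sketch of the diagonal-removal step for N = 2, with a mask built ad hoc for illustration rather than taken from the repo's _create_buffer.

import torch

N = 2
s = torch.randn(2 * N, 2 * N)            # stand-in for the similarity matrix z @ z.T
mask = 1 - torch.eye(2 * N)              # float buffer: 0 on the diagonal, 1 elsewhere
s_nodiag = torch.masked_select(s, mask == 1).reshape(2 * N, -1)
print(s_nodiag.shape)                    # torch.Size([4, 3]), i.e. (2N)x(2N-1)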