# mmpretrain/configs/swav/swav_resnet50_8xb32-mcrop-c...

_base_ = [
    '../_base_/schedules/imagenet_lars_coslr_200e.py',
    '../_base_/default_runtime.py',
]
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
    type='SelfSupDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True)
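# Multi-crop setting of SwAV: 2 global 224px views (view_pipeline1) and
# 6 local 96px views (view_pipeline2), i.e. 8 views per image in total.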
num_crops = [2, 6]
color_distort_strength = 1.0
view_pipeline1 = [
    dict(
        type='RandomResizedCrop',
        scale=224,
        crop_ratio_range=(0.14, 1.),
        backend='pillow'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.8 * color_distort_strength,
                contrast=0.8 * color_distort_strength,
                saturation=0.8 * color_distort_strength,
                hue=0.2 * color_distort_strength)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989)),
    dict(
        type='GaussianBlur',
        magnitude_range=(0.1, 2.0),
        magnitude_std='inf',
        prob=0.5),
    dict(type='RandomFlip', prob=0.5),
]
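# view_pipeline2 below differs from view_pipeline1 only in the crop settings:
# a 96px output size and a smaller crop_ratio_range of (0.05, 0.14), which
# produces the local views of the multi-crop scheme.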
view_pipeline2 = [
    dict(
        type='RandomResizedCrop',
        scale=96,
        crop_ratio_range=(0.05, 0.14),
        backend='pillow'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.8 * color_distort_strength,
                contrast=0.8 * color_distort_strength,
                saturation=0.8 * color_distort_strength,
                hue=0.2 * color_distort_strength)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989)),
    dict(
        type='GaussianBlur',
        magnitude_range=(0.1, 2.0),
        magnitude_std='inf',
        prob=0.5),
    dict(type='RandomFlip', prob=0.5),
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiView',
        num_views=num_crops,
        transforms=[view_pipeline1, view_pipeline2]),
    dict(type='PackInputs')
]
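# MultiView applies each pipeline in `transforms` the number of times given by
# the matching entry of `num_views`, so every sample is packed as a list of
# 2 + 6 = 8 augmented crops.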
batch_size = 32
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=8,
    drop_last=True,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline))
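# batch_size=32 is the per-GPU batch size; following the "8xb32" file-name
# convention this config is intended for 8 GPUs, i.e. a total batch of 256.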
# model settings
model = dict(
    type='SwAV',
    data_preprocessor=dict(
        mean=(123.675, 116.28, 103.53),
        std=(58.395, 57.12, 57.375),
        to_rgb=True),
    backbone=dict(
        type='ResNet',
        depth=50,
        norm_cfg=dict(type='SyncBN'),
        zero_init_residual=True),
    neck=dict(
        type='SwAVNeck',
        in_channels=2048,
        hid_channels=2048,
        out_channels=128,
        with_avg_pool=True),
    head=dict(
        type='SwAVHead',
        loss=dict(
            type='SwAVLoss',
            feat_dim=128,  # equal to neck['out_channels']
            epsilon=0.05,
            temperature=0.1,
            num_crops=num_crops,
        )))
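# In the SwAV loss, epsilon is the entropy regularization used by the
# Sinkhorn-Knopp assignment and temperature scales the softmax over prototype
# scores; feat_dim must match the neck's out_channels.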
# optimizer
optim_wrapper = dict(type='OptimWrapper', optimizer=dict(type='LARS', lr=0.6))
find_unused_parameters = True
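# Only the learning rate is overridden here; the remaining LARS settings are
# merged in from the imported schedule. lr=0.6 is the base rate for this
# 8xb32 (total batch 256) setup. find_unused_parameters=True lets DDP
# tolerate parameters that receive no gradient in a given iteration.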
# learning policy
param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        T_max=200,
        eta_min=6e-4,
        by_epoch=True,
        begin=0,
        end=200,
        convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
    # only keeps the latest 3 checkpoints
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
# additional hooks
custom_hooks = [
    dict(
        type='SwAVHook',
        priority='VERY_HIGH',
        batch_size=batch_size,
        epoch_queue_starts=15,
        crops_for_assign=[0, 1],
        feat_dim=128,
        queue_length=3840,
        frozen_layers_cfg=dict(prototypes=5005))
]
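# SwAVHook handles the extra bookkeeping of the SwAV recipe: it maintains a
# feature queue of 3840 entries (enabled from epoch 15) to stabilize the
# Sinkhorn-Knopp assignments and freezes the prototype weights for the first
# 5005 iterations. crops_for_assign=[0, 1] means cluster assignments are
# computed from the two global views only, as in the SwAV paper.
#
# A possible launch command (an assumption based on the usual OpenMMLab
# tools/ layout, not part of this config):
#   bash tools/dist_train.sh configs/swav/<this_config>.py 8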