mmpretrain/projects/dino/config/dino_vit-base-p16_8xb64-amp...

105 lines
3.1 KiB
Python

model = dict(
type='DINO',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
neck=dict(
type='DINONeck',
in_channels=768,
out_channels=65536,
hidden_channels=2048,
bottleneck_channels=256),
head=dict(
type='DINOHead',
out_channels=65536,
num_crops=10,
student_temp=0.1,
center_momentum=0.9))
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='DINOMultiCrop',
global_crops_scale=(0.4, 1.0),
local_crops_scale=(0.05, 0.4),
local_crops_number=8),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=32,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type='mmpretrain.ImageNet',
data_root='/data/imagenet/',
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline,
))
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
paramwise_cfg=dict(
custom_keys=dict(
ln=dict(decay_mult=0.0),
bias=dict(decay_mult=0.0),
pos_embed=dict(decay_mult=0.0),
mask_token=dict(decay_mult=0.0),
cls_token=dict(decay_mult=0.0))),
loss_scale='dynamic')
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-09,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=90,
by_epoch=True,
begin=10,
end=100,
convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)
default_scope = 'mmpretrain'
default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
log_processor = dict(
window_size=10,
custom_cfg=[dict(data_src='', method='mean', window_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='UniversalVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
log_level = 'INFO'
load_from = None
resume = True
randomness = dict(seed=2, diff_rank_seed=True)
custom_hooks = [
dict(
type='DINOTeacherTempWarmupHook',
warmup_teacher_temp=0.04,
teacher_temp=0.04,
teacher_temp_warmup_epochs=0,
max_epochs=100)
]