_base_ = [
    '../_base_/datasets/imagenet_bs256_simmim_192.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='SimMIM',
    backbone=dict(
        type='SimMIMSwinTransformer',
        arch='base',
        img_size=192,
        stage_cfgs=dict(block_cfgs=dict(window_size=6))),
    neck=dict(
        type='SimMIMLinearDecoder',
        in_channels=128 * 2**3,
        encoder_stride=32),
    head=dict(
        type='SimMIMHead',
        patch_size=4,
        loss=dict(type='PixelReconstructionLoss', criterion='L1', channel=3)))

# optimizer wrapper
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=1e-4 * 2048 / 512,
        betas=(0.9, 0.999),
        weight_decay=0.05),
    clip_grad=dict(max_norm=5.0),
    paramwise_cfg=dict(
        custom_keys={
            'norm': dict(decay_mult=0.0),
            'bias': dict(decay_mult=0.0),
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=5e-7 / 1e-4,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='MultiStepLR',
        milestones=[700],
        by_epoch=True,
        begin=10,
        end=800,
        convert_to_iter_based=True)
]

# runtime
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
    # only keeps the latest 3 checkpoints
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)
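
# Usage sketch (not part of the original config): with the settings above,
# the base LR works out to 1e-4 * 2048 / 512 = 4e-4 for a total batch size
# of 2048 (e.g. 16 GPUs x 128 images per GPU), and `auto_scale_lr` rescales
# it linearly when the actual batch size differs. Assuming the standard
# OpenMMLab entry points, training could be launched with, for example
# (`<this_config>.py` is a placeholder for this file's path):
#   bash tools/dist_train.sh <this_config>.py 16
# or, on fewer GPUs with automatic LR scaling enabled:
#   bash tools/dist_train.sh <this_config>.py 8 --auto-scale-lr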