[Fix] fix config of maskfeat (#1424)

pull/1447/head
Yixiao Fang 2023-03-30 11:45:18 +08:00 committed by GitHub
parent 445eb3223a
commit 9fb4e9c911
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 8 deletions

View File

@ -43,15 +43,12 @@ train_dataloader = dict(
# model settings # model settings
model = dict( model = dict(
type='MaskFeat', type='MaskFeat',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
backbone=dict(type='MaskFeatViT', arch='b', patch_size=16), backbone=dict(type='MaskFeatViT', arch='b', patch_size=16),
neck=dict( neck=dict(
type='LinearNeck', type='LinearNeck',
in_channels=768, in_channels=768,
out_channels=108, out_channels=108,
norm_cfg=None,
init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)), init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)),
head=dict( head=dict(
type='MIMHead', type='MIMHead',
@ -67,13 +64,13 @@ optim_wrapper = dict(
type='AdamW', lr=2e-4 * 8, betas=(0.9, 0.999), weight_decay=0.05), type='AdamW', lr=2e-4 * 8, betas=(0.9, 0.999), weight_decay=0.05),
clip_grad=dict(max_norm=0.02), clip_grad=dict(max_norm=0.02),
paramwise_cfg=dict( paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0, bias_decay_mult=0.0,
# commented 'pos_embed' and 'cls_token' to avoid loss stuck situation norm_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={ custom_keys={
# 'pos_embed': dict(decay_mult=0.), # 'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.), # 'cls_token': dict(decay_mult=0.),
# 'cls_token': dict(decay_mult=0.) 'mask_token': dict(decay_mult=0.)
})) }))
# learning rate scheduler # learning rate scheduler
@ -88,6 +85,7 @@ param_scheduler = [
dict( dict(
type='CosineAnnealingLR', type='CosineAnnealingLR',
T_max=270, T_max=270,
eta_min=1e-6,
by_epoch=True, by_epoch=True,
begin=30, begin=30,
end=300, end=300,