[Fix] fix config of maskfeat (#1424)

pull/1447/head
Yixiao Fang 2023-03-30 11:45:18 +08:00 committed by GitHub
parent 445eb3223a
commit 9fb4e9c911
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 8 deletions

View File

@ -43,15 +43,12 @@ train_dataloader = dict(
# model settings
model = dict(
type='MaskFeat',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
backbone=dict(type='MaskFeatViT', arch='b', patch_size=16),
neck=dict(
type='LinearNeck',
in_channels=768,
out_channels=108,
norm_cfg=None,
init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)),
head=dict(
type='MIMHead',
@ -67,13 +64,13 @@ optim_wrapper = dict(
type='AdamW', lr=2e-4 * 8, betas=(0.9, 0.999), weight_decay=0.05),
clip_grad=dict(max_norm=0.02),
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
# commented 'pos_embed' and 'cls_token' to avoid loss stuck situation
norm_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
# 'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
# 'cls_token': dict(decay_mult=0.)
# 'cls_token': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.)
}))
# learning rate scheduler
@ -88,6 +85,7 @@ param_scheduler = [
dict(
type='CosineAnnealingLR',
T_max=270,
eta_min=1e-6,
by_epoch=True,
begin=30,
end=300,