[Fix]: Refine simmim-224-ft config (#473)

pull/508/head
Yuan Liu 2022-09-07 15:18:27 +08:00 committed by GitHub
parent 0c30969d6f
commit 776d64b7cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 28 deletions

View File

@ -5,44 +5,32 @@ model = dict(
backbone=dict(
img_size=224, stage_cfgs=dict(block_cfgs=dict(window_size=7))))
# dataset
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
preprocess_cfg = dict(
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
to_rgb=True,
)
bgr_mean = preprocess_cfg['pixel_mean'][::-1]
bgr_std = preprocess_cfg['pixel_std'][::-1]
# train pipeline
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='mmcls.RandomResizedCrop',
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='mmcls.RandAugment',
policies={{_base_.rand_increasing_policies}},
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='mmcls.RandomErasing',
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackSelfSupInputs', algorithm_keys=['gt_label']),
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackClsInputs')
]
# test pipeline
@ -55,12 +43,8 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackSelfSupInputs', algorithm_keys=['gt_label']),
dict(type='PackClsInputs')
]
data = dict(
samples_per_gpu=256,
drop_last=False,
workers_per_gpu=32,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline))
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))