[Fix]: Refine simmim-224-ft config (#473)
parent
0c30969d6f
commit
776d64b7cd
|
@ -5,44 +5,32 @@ model = dict(
|
|||
backbone=dict(
|
||||
img_size=224, stage_cfgs=dict(block_cfgs=dict(window_size=7))))
|
||||
|
||||
# dataset
|
||||
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||
preprocess_cfg = dict(
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = preprocess_cfg['pixel_mean'][::-1]
|
||||
bgr_std = preprocess_cfg['pixel_std'][::-1]
|
||||
|
||||
# train pipeline
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='mmcls.RandomResizedCrop',
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='mmcls.RandAugment',
|
||||
policies={{_base_.rand_increasing_policies}},
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
|
||||
dict(
|
||||
type='mmcls.RandomErasing',
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackSelfSupInputs', algorithm_keys=['gt_label']),
|
||||
max_area_ratio=0.3333333333333333,
|
||||
fill_color=[103.53, 116.28, 123.675],
|
||||
fill_std=[57.375, 57.12, 58.395]),
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
|
||||
# test pipeline
|
||||
|
@ -55,12 +43,8 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackSelfSupInputs', algorithm_keys=['gt_label']),
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=256,
|
||||
drop_last=False,
|
||||
workers_per_gpu=32,
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline))
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
|
|
Loading…
Reference in New Issue