133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# This is a BETA new format config file, and the usage may change recently.
|
|
from mmengine.config import read_base
|
|
|
|
with read_base():
|
|
from ..._base_.datasets.imagenet_bs64_swin_224 import *
|
|
from ..._base_.schedules.imagenet_bs1024_adamw_swin import *
|
|
from ..._base_.default_runtime import *
|
|
|
|
from mmengine.model import PretrainedInit, TruncNormalInit
|
|
from mmengine.optim import CosineAnnealingLR, LinearLR
|
|
from torch.optim import AdamW
|
|
|
|
from mmpretrain.engine.optimizers import \
|
|
LearningRateDecayOptimWrapperConstructor
|
|
from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss,
|
|
LinearClsHead)
|
|
from mmpretrain.models.utils.batch_augments import CutMix, Mixup
|
|
|
|
# model settings
|
|
model = dict(
|
|
type=ImageClassifier,
|
|
backbone=dict(
|
|
type=BEiTViT,
|
|
arch='base',
|
|
img_size=224,
|
|
patch_size=16,
|
|
# 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs.
|
|
drop_path_rate=0.1,
|
|
out_type='avg_featmap',
|
|
use_abs_pos_emb=False,
|
|
use_rel_pos_bias=True,
|
|
use_shared_rel_pos_bias=False,
|
|
init_cfg=dict(type=PretrainedInit, checkpoint='', prefix='backbone.')),
|
|
neck=None,
|
|
head=dict(
|
|
type=LinearClsHead,
|
|
num_classes=1000,
|
|
in_channels=768,
|
|
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
|
|
init_cfg=[dict(type=TruncNormalInit, layer='Linear', std=0.02)]),
|
|
train_cfg=dict(
|
|
augments=[dict(type=Mixup, alpha=0.8),
|
|
dict(type=CutMix, alpha=1.0)]))
|
|
|
|
train_pipeline = [
|
|
dict(type=LoadImageFromFile),
|
|
dict(
|
|
type=RandomResizedCrop,
|
|
scale=224,
|
|
backend='pillow',
|
|
interpolation='bicubic'),
|
|
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
|
dict(
|
|
type=RandAugment,
|
|
policies='timm_increasing',
|
|
num_policies=2,
|
|
total_level=10,
|
|
magnitude_level=9,
|
|
magnitude_std=0.5,
|
|
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
|
|
dict(
|
|
type=RandomErasing,
|
|
erase_prob=0.25,
|
|
mode='rand',
|
|
min_area_ratio=0.02,
|
|
max_area_ratio=0.3333333333333333,
|
|
fill_color=[103.53, 116.28, 123.675],
|
|
fill_std=[57.375, 57.12, 58.395]),
|
|
dict(type=PackInputs)
|
|
]
|
|
test_pipeline = [
|
|
dict(type=LoadImageFromFile),
|
|
dict(
|
|
type=ResizeEdge,
|
|
scale=256,
|
|
edge='short',
|
|
backend='pillow',
|
|
interpolation='bicubic'),
|
|
dict(type=CenterCrop, crop_size=224),
|
|
dict(type=PackInputs)
|
|
]
|
|
|
|
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
|
|
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
|
|
test_dataloader = val_dataloader
|
|
|
|
# optimizer wrapper
|
|
optim_wrapper = dict(
|
|
optimizer=dict(type=AdamW, lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999)),
|
|
constructor=LearningRateDecayOptimWrapperConstructor,
|
|
paramwise_cfg=dict(
|
|
_delete_=True,
|
|
# 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
|
|
layer_decay_rate=0.65,
|
|
custom_keys={
|
|
# the following configurations are designed for BEiT
|
|
'.ln': dict(decay_mult=0.0),
|
|
'.bias': dict(decay_mult=0.0),
|
|
'q_bias': dict(decay_mult=0.0),
|
|
'v_bias': dict(decay_mult=0.0),
|
|
'.cls_token': dict(decay_mult=0.0),
|
|
'.pos_embed': dict(decay_mult=0.0),
|
|
'.gamma': dict(decay_mult=0.0),
|
|
}))
|
|
|
|
# learning rate scheduler
|
|
param_scheduler = [
|
|
dict(
|
|
type=LinearLR,
|
|
start_factor=1e-4,
|
|
by_epoch=True,
|
|
begin=0,
|
|
end=20,
|
|
convert_to_iter_based=True),
|
|
dict(
|
|
type=CosineAnnealingLR,
|
|
by_epoch=True,
|
|
begin=20,
|
|
end=100,
|
|
eta_min=1e-6,
|
|
convert_to_iter_based=True)
|
|
]
|
|
|
|
# runtime settings
|
|
default_hooks = dict(
|
|
# save checkpoint per epoch.
|
|
checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=2))
|
|
|
|
train_cfg = dict(by_epoch=True, max_epochs=100)
|
|
|
|
randomness = dict(seed=0)
|