
Fine-tune Models

In most scenarios, we want to apply a pre-trained model rather than train from scratch, since training from scratch is time-consuming and adds uncertainty about whether the model will converge. The usual approach is to start from a model trained on a large dataset, which hopefully provides a better starting point than random initialization. Roughly speaking, this process is known as fine-tuning.

Models pre-trained on the ImageNet dataset have been demonstrated to be effective for other datasets and other downstream tasks. This tutorial therefore shows how to use the models provided in the Model Zoo on other datasets to obtain better performance.

There are two steps to fine-tune a model on a new dataset.

  • Add support for the new dataset following Prepare Dataset.
  • Modify the configs as will be discussed in this tutorial.

Assume we have a ResNet-50 model pre-trained on the ImageNet-2012 dataset and want to fine-tune it on the CIFAR-10 dataset. To do so, we need to modify five parts in the config.

Inherit base configs

First, create a new config file configs/tutorial/resnet50_finetune_cifar.py to store our fine-tuning configs. You can, of course, choose a different path.

To reuse the common parts among different configs, we support inheriting from multiple existing base configs. A config typically inherits the following four parts:

  • Model configs: To fine-tune a ResNet-50 model, the new config needs to inherit configs/_base_/models/resnet50.py to build the basic structure of the model.
  • Dataset configs: To use the CIFAR-10 dataset, the new config can simply inherit configs/_base_/datasets/cifar10_bs16.py.
  • Schedule configs: The new config can inherit configs/_base_/schedules/cifar10_bs128.py for the CIFAR-10 dataset with a batch size of 128.
  • Runtime configs: For runtime settings such as basic hooks, etc., the new config needs to inherit configs/_base_/default_runtime.py.

To inherit all the configs above, put the following code in the config file.

_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py',
    '../_base_/default_runtime.py',
]

Alternatively, you can write out the whole config instead of using inheritance. Refer to configs/lenet/lenet5_mnist.py for more details.
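
As a rough mental model of what inheritance does, the sketch below merges a child config over a base config, key by key. It is plain Python rather than the library's actual merging code; `merge_config`, `base_model`, and the trimmed field values are hypothetical names chosen for illustration only.

```python
import copy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a deep copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them field by field.
            merged[key] = merge_config(merged[key], value)
        else:
            # Otherwise the child value replaces the inherited one.
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical, heavily trimmed stand-in for an inherited model config.
base_model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=2048),
)

# The child config only states what it wants to change.
override = dict(head=dict(num_classes=10))

merged = merge_config(base_model, override)
print(merged['head'])
# {'type': 'LinearClsHead', 'num_classes': 10, 'in_channels': 2048}
```

Only `num_classes` changes; every field the child config does not mention is kept from the base.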

Specify pre-trained model in configs

When fine-tuning a model, usually we want to load the pre-trained backbone weights and train a new classification head from scratch.

To load the pre-trained backbone, we need to change the initialization config of the backbone and use the Pretrained initialization function. In init_cfg, we also set prefix='backbone' to tell the initialization function which submodule of the checkpoint to load.

For example, backbone here means that only the backbone submodule is loaded. Here we use an online checkpoint, which will be downloaded automatically during training; you can also download the model manually and use a local path. We then need to adapt the head to the number of classes in the new dataset by changing num_classes in the head.

model = dict(
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)

Here we only need to set the parts of the config we want to modify, because the
inherited configs are merged with them to produce the complete config.
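
The effect of prefix='backbone' can be illustrated with a small sketch: keep only the checkpoint entries whose names start with the prefix, and strip the prefix so the remaining names match the backbone's own parameter names. This is a simplified illustration under that assumption, not the library's loading code; `select_by_prefix` and the toy checkpoint keys are hypothetical.

```python
def select_by_prefix(state_dict: dict, prefix: str) -> dict:
    """Keep entries under `prefix` and drop the prefix from their keys."""
    if not prefix.endswith('.'):
        prefix += '.'
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical checkpoint of a full classifier (strings stand in for tensors).
checkpoint = {
    'backbone.conv1.weight': 'w0',
    'backbone.layer1.0.conv1.weight': 'w1',
    'head.fc.weight': 'w2',
    'head.fc.bias': 'b2',
}

backbone_weights = select_by_prefix(checkpoint, 'backbone')
print(sorted(backbone_weights))
# ['conv1.weight', 'layer1.0.conv1.weight']
```

The head entries are dropped, which is what we want here: the new 10-class head is trained from scratch.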

When the new dataset is small and shares its domain with the pre-training dataset, we might want to freeze the parameters of the first several stages of the backbone. This helps the network keep the ability to extract low-level information learned by the pre-trained model. In MMPretrain, you can specify how many stages to freeze with the frozen_stages argument. For example, to freeze the parameters of the first two stages, use the following config:

model = dict(
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)

Not all backbones support the `frozen_stages` argument yet. Please check
[the docs](https://mmpretrain.readthedocs.io/en/main/api.html#module-mmpretrain.models.backbones)
to confirm whether your backbone supports it.
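
To get a feel for what frozen_stages=2 leaves trainable, here is a minimal sketch. It assumes a ResNet-style layout in which the stem is frozen whenever frozen_stages is non-negative and stages are named layer1 to layer4; both the naming and the rule are assumptions made for illustration, and `is_frozen` is a hypothetical helper, so check your backbone's documentation for the real behavior.

```python
def is_frozen(param_name: str, frozen_stages: int) -> bool:
    """Assumed rule: freeze the stem plus layer1..layer<frozen_stages>."""
    top = param_name.split('.')[0]
    suffix = top[len('layer'):] if top.startswith('layer') else ''
    if suffix.isdigit():
        # A residual stage: frozen if its index is within the frozen range.
        return int(suffix) <= frozen_stages
    # Anything else is treated as the stem (assumption).
    return frozen_stages >= 0


# Hypothetical parameter names of a ResNet-style backbone.
names = [
    'conv1.weight',
    'bn1.weight',
    'layer1.0.conv1.weight',
    'layer2.0.conv1.weight',
    'layer3.0.conv1.weight',
    'layer4.0.conv1.weight',
]

trainable = [name for name in names if not is_frozen(name, frozen_stages=2)]
print(trainable)
# ['layer3.0.conv1.weight', 'layer4.0.conv1.weight']
```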

Modify dataset configs

When fine-tuning on a new dataset, we usually need to modify some dataset configs. Here, we need to modify the pipeline to resize the images from 32 to 224 to fit the input size of the model pre-trained on ImageNet, and modify the dataloaders accordingly.

# data pipeline settings
train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
# dataloader settings
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

Modify training schedule configs

The hyperparameters for fine-tuning differ from the default schedule. Fine-tuning usually requires a smaller learning rate and a schedule that decays after fewer epochs.

# lr is set for a batch size of 128
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[15], gamma=0.1)

Refer to [Learn about Configs](config.md) for more detailed configurations.
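
The schedule above can be read as a step function of the epoch. The sketch below evaluates it in plain Python, assuming epochs are counted from 0 and the learning rate is multiplied by gamma once each milestone is reached; `multistep_lr` is a hypothetical helper, not part of the library.

```python
from bisect import bisect_right


def multistep_lr(base_lr: float, milestones: list, gamma: float, epoch: int) -> float:
    """Learning rate at `epoch` for a multi-step decay schedule."""
    # Count how many milestones have been reached by this epoch.
    num_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** num_decays


# Values taken from the config above: lr=0.01, milestones=[15], gamma=0.1.
for epoch in (0, 14, 15, 20):
    print(epoch, multistep_lr(0.01, [15], 0.1, epoch))
```

Under these assumptions the learning rate stays at 0.01 for epochs 0 to 14 and drops to roughly 0.001 from epoch 15 onward.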

Start Training

Now we have finished the fine-tuning config file, as follows:

_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py',
    '../_base_/default_runtime.py',
]

# Model config
model = dict(
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)

# Dataset config
# data pipeline settings
train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
# dataloader settings
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# Training schedule config
# lr is set for a batch size of 128
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[15], gamma=0.1)

Here we use 8 GPUs on a single machine to train the model with the following command:

bash tools/dist_train.sh configs/tutorial/resnet50_finetune_cifar.py 8

Also, you can use only one GPU to train the model with the following command:

python tools/train.py configs/tutorial/resnet50_finetune_cifar.py

However, an important config needs to be changed when using one GPU. We need to change the dataset config as follows:

train_dataloader = dict(
    batch_size=128,
    dataset=dict(pipeline=train_pipeline),
)
val_dataloader = dict(
    batch_size=128,
    dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader

This is because our training schedule assumes a total batch size of 128. With 8 GPUs, the batch_size=16 setting from the base config applies to each GPU, so the total batch size is 128. With one GPU, you need to set it to 128 manually to match the training schedule.
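
The arithmetic behind this can be sketched as a small helper that derives the per-GPU batch_size from the total batch size the schedule was tuned for. `per_gpu_batch_size` is a hypothetical function written for illustration; it is not part of MMPretrain.

```python
def per_gpu_batch_size(total_batch_size: int, num_gpus: int) -> int:
    """Return the per-GPU batch size that keeps the total batch size fixed."""
    if num_gpus <= 0:
        raise ValueError('num_gpus must be positive')
    if total_batch_size % num_gpus != 0:
        raise ValueError('total_batch_size must be divisible by num_gpus')
    return total_batch_size // num_gpus


print(per_gpu_batch_size(128, 8))  # 16, the value in the base config
print(per_gpu_batch_size(128, 1))  # 128, the value to set for one GPU
```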

Apply pre-trained model with command line

If you don't want to modify the configs, you could use --cfg-options to add your pre-trained model path to init_cfg.

For example, the command below will also load the pre-trained model.

bash tools/dist_train.sh configs/tutorial/resnet50_finetune_cifar.py 8 \
    --cfg-options model.backbone.init_cfg.type='Pretrained' \
    model.backbone.init_cfg.checkpoint='https://download.openmmlab.com/mmselfsup/1.x/mocov3/mocov3_resnet50_8xb512-amp-coslr-100e_in1k/mocov3_resnet50_8xb512-amp-coslr-100e_in1k_20220927-f1144efa.pth' \
    model.backbone.init_cfg.prefix='backbone'
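
Conceptually, each dotted key passed to --cfg-options addresses one field in the nested config. The sketch below shows that idea in plain Python; `apply_dotted_overrides` and the trimmed `cfg` dict are hypothetical and only approximate how such overrides are applied.

```python
def apply_dotted_overrides(cfg: dict, overrides: dict) -> dict:
    """Set nested fields in `cfg` from keys such as 'a.b.c'."""
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for name in parents:
            # Create intermediate dicts that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = value
    return cfg


# Hypothetical, trimmed config before the command-line overrides.
cfg = {'model': {'backbone': {'type': 'ResNet', 'depth': 50}}}

apply_dotted_overrides(cfg, {
    'model.backbone.init_cfg.type': 'Pretrained',
    'model.backbone.init_cfg.prefix': 'backbone',
})
print(cfg['model']['backbone']['init_cfg'])
# {'type': 'Pretrained', 'prefix': 'backbone'}
```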