# Fine-tune Models
In most scenarios, we want to apply a model to new datasets without training it from scratch, which can be time-consuming and may introduce extra uncertainty about model convergence.
The common practice is to start from a model trained on a large dataset, which usually provides better knowledge than random initialization. Roughly speaking, this process is known as fine-tuning.
Classification models pre-trained on the ImageNet dataset have been demonstrated to be effective for other datasets and other downstream tasks.
Hence, this tutorial provides instructions on using the models provided in the [Model Zoo](../modelzoo_statistics.md) on other datasets to obtain better performance.
There are two steps to fine-tune a model on a new dataset.
- Add support for the new dataset following [Prepare Dataset](dataset_prepare.md).
- Modify the configs as will be discussed in this tutorial.
Assume we have a ResNet-50 model pre-trained on the ImageNet-2012 dataset and want
to fine-tune it on the CIFAR-10 dataset. We need to modify the config in the following parts.
## Inherit base configs
At first, create a new config file
`configs/tutorial/resnet50_finetune_cifar.py` to store our fine-tune configs. Of course,
the path can be customized.
To reuse the common parts of different base configs, we support inheriting
from multiple existing configs. The fine-tune config inherits the following four parts:
- Model configs: To fine-tune a ResNet-50 model, the new
config needs to inherit `configs/_base_/models/resnet50.py` to build the basic structure of the model.
- Dataset configs: To use the CIFAR-10 dataset, the new config can simply
inherit `configs/_base_/datasets/cifar10_bs16.py`.
- Schedule configs: The new config can inherit `configs/_base_/schedules/cifar10_bs128.py`,
which is designed for the CIFAR-10 dataset with a batch size of 128.
- Runtime configs: For runtime settings such as basic hooks, etc.,
the new config needs to inherit `configs/_base_/default_runtime.py`.

To inherit all the configs above, put the following code at the beginning of the config file.
```python
_base_ = [
'../_base_/models/resnet50.py',
'../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py',
'../_base_/default_runtime.py',
]
```
Besides, you can also choose to write the whole config from scratch rather than use inheritance.
Refer to [`configs/lenet/lenet5_mnist.py`](https://github.com/open-mmlab/mmclassification/blob/master/configs/lenet/lenet5_mnist.py) for more details.
## Modify model configs
When fine-tuning a model, we usually want to load the pre-trained backbone
weights and train a new classification head from scratch.
To load the pre-trained backbone, we need to change the initialization config
of the backbone and use the `Pretrained` initialization function. Besides, in the
`init_cfg`, we use `prefix='backbone'` to tell the initialization function
the prefix of the submodule to load from the checkpoint.
For example, `backbone` here means to load the backbone submodule. We use an
online checkpoint here, so it will be downloaded automatically during training;
you can also download the model manually and use a local path.
Then we need to modify the head according to the number of classes of the new
dataset by simply changing `num_classes` in the head.
```python
model = dict(
backbone=dict(
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
prefix='backbone',
)),
head=dict(num_classes=10),
)
```
```{tip}
Here we only need to set the parts of the config we want to modify, because the
inherited configs will be merged to produce the complete config.
```
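
If you want to check what the merged config looks like, one option (a minimal sketch, assuming MMEngine is installed, which MMClassification depends on) is to load the file with `mmengine.config.Config` and inspect the result:

```python
# A minimal sketch: load the fine-tune config created above and inspect the
# merged result of all inherited base configs.
from mmengine.config import Config

cfg = Config.fromfile('configs/tutorial/resnet50_finetune_cifar.py')
print(cfg.model.head.num_classes)  # 10, overriding the base model config
print(cfg.pretty_text)             # the full merged config as text
```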
When the new dataset is small and shares the domain with the pre-trained dataset,
we might want to freeze the parameters of the first several stages of the
backbone, which helps the network keep its ability to extract the low-level
information learned from the pre-trained model. In MMClassification, you can simply
specify how many stages to freeze with the `frozen_stages` argument. For example, to
freeze the parameters of the first two stages, just use the following config:
```python
model = dict(
backbone=dict(
frozen_stages=2,
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
prefix='backbone',
)),
head=dict(num_classes=10),
)
```
```{note}
Not all backbones support the `frozen_stages` argument at the moment. Please check
[the docs](https://mmclassification.readthedocs.io/en/1.x/api.html#module-mmpretrain.models.backbones)
to confirm if your backbone supports it.
```
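
If you want to verify that freezing took effect, a quick sanity check is to build the model from the merged config and list the backbone parameters that no longer require gradients. The snippet below is only a sketch under assumptions: it uses MMEngine's registry, and the default scope name (`mmcls` or `mmpretrain`) depends on the version you have installed.

```python
# A hedged sketch: build the model from the config and confirm that the early
# backbone stages are frozen (their parameters have requires_grad=False).
from mmengine.config import Config
from mmengine.registry import MODELS, init_default_scope

cfg = Config.fromfile('configs/tutorial/resnet50_finetune_cifar.py')
init_default_scope('mmcls')  # use 'mmpretrain' if that is your installed package
model = MODELS.build(cfg.model)
model.train()  # stage freezing is applied when the model enters training mode

frozen = [n for n, p in model.backbone.named_parameters() if not p.requires_grad]
print(f'{len(frozen)} frozen backbone parameters, e.g. {frozen[:3]}')
```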
## Modify dataset configs
When fine-tuning on a new dataset, we usually need to modify some dataset
configs. Here, we need to modify the pipeline to resize the images from 32 to
224 to fit the input size of the model pre-trained on ImageNet, and modify the
dataloaders accordingly.
```python
# data pipeline settings
train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
# dataloader settings
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
```
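
To double-check that the training pipeline really produces 224x224 inputs, you can dry-run it on a dummy CIFAR-sized image. This is only a sketch under assumptions: the transform names above must be registered in your installed version (for example, `PackInputs` was called `PackClsInputs` in some earlier pre-releases), and the default scope may be `mmcls` or `mmpretrain`.

```python
# A hedged sketch: feed a dummy 32x32 image through the training pipeline and
# check the packed input shape.
import numpy as np
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope

init_default_scope('mmcls')  # use 'mmpretrain' if that is your installed package
pipeline = Compose(train_pipeline)  # the train_pipeline list defined above
results = pipeline({'img': np.zeros((32, 32, 3), dtype=np.uint8), 'gt_label': 0})
print(results['inputs'].shape)  # expected: torch.Size([3, 224, 224])
```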
## Modify training schedule configs
The fine-tuning hyperparameters differ from the default schedule. Fine-tuning usually
requires a smaller learning rate and a faster decay schedule.
```python
# lr is set for a batch size of 128
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[15], gamma=0.1)
```
```{tip}
Refer to [Learn about Configs](config.md) for more detailed configurations.
```
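
If you also want to shorten the fine-tuning run itself, the total number of epochs comes from the training loop settings inherited from the schedule config. The snippet below is a sketch under assumptions: it presumes the base schedule uses MMEngine's epoch-based training loop, and the values are only an illustration.

```python
# Assumption: the inherited schedule is epoch-based. Fine-tuning usually
# converges quickly, so a short run with per-epoch validation is often enough.
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
```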
## Start Training
Now, we have finished the fine-tuning config file as follows:
```python
_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py',
    '../_base_/default_runtime.py',
]

# Model config
model = dict(
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)

# Dataset config
# data pipeline settings
train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]
# dataloader settings
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# Training schedule config
# lr is set for a batch size of 128
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[15], gamma=0.1)
```
Here we use 8 GPUs to train the model with the following command:

```shell
bash tools/dist_train.sh configs/tutorial/resnet50_finetune_cifar.py 8
```
Also, you can use only one GPU to train the model with the following command:
```shell
python tools/train.py configs/tutorial/resnet50_finetune_cifar.py
```
However, an important config needs to be changed when using only one GPU. We need
to change the dataloader config as follows:
```python
train_dataloader = dict(
    batch_size=128,
    dataset=dict(pipeline=train_pipeline),
)
val_dataloader = dict(
    batch_size=128,
    dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader
```
This is because our training schedule assumes a total batch size of 128. With 8 GPUs,
just use the `batch_size=16` setting from the base config file on every GPU, and the total batch
size will be 128. But with one GPU, you need to change the batch size to 128 manually to
match the training schedule.
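
Alternatively, you can keep the per-GPU batch size of 16 and instead scale the learning rate with the linear scaling rule. This is not part of the configs shown above, just a hedged sketch of the common practice:

```python
# Hypothetical alternative: keep batch_size=16 on a single GPU and scale the
# learning rate linearly, 0.01 * (16 / 128), to roughly match the 128-batch schedule.
optim_wrapper = dict(
    optimizer=dict(
        type='SGD', lr=0.01 * 16 / 128, momentum=0.9, weight_decay=0.0001))
```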