7.4 KiB
How to Pretrain with Custom Dataset
In this tutorial, we provide a practice example and some tips on how to train on your own dataset.
In MMPretrain, We support the CustomDataset
(similar to the ImageFolder
in torchvision
), which is able to read the images within the specified folder directly. You only need to prepare the path information of the custom dataset and edit the config.
Step-1: Prepare your dataset
Prepare your dataset following Prepare Dataset.
And the root folder of the dataset can be like data/custom_dataset/
.
Here, we assume you want to do unsupervised training, and use the sub-folder format CustomDataset
to
organize your dataset as:
data/custom_dataset/
├── sample1.png
├── sample2.png
├── sample3.png
├── sample4.png
└── ...
Step-2: Choose one config as template
Here, we would like to use configs/mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
as the example. We
first copy this config file to the same folder and rename it as
mae_vit-base-p16_8xb512-amp-coslr-300e_custom.py
.
As a convention, the last field of the config name is the dataset, e.g.,`in1k` for ImageNet dataset, `coco` for COCO dataset
The content of this config is:
_base_ = [
'../_base_/models/mae_vit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=260,
by_epoch=True,
begin=40,
end=300,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)
Step-3: Edit the dataset related config
- Override the
type
of dataset settings as'CustomDataset'
- Override the
data_root
of dataset settings asdata/custom_dataset
. - Override the
ann_file
of dataset settings as an empty string since we assume you are using the sub-folder formatCustomDataset
. - Override the
data_prefix
of dataset settings as an empty string since we are using the whole dataset under thedata_root
, and you don't need to split samples into different subset and set thedata_prefix
.
The modified config will be like:
_base_ = [
'../_base_/models/mae_vit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# >>>>>>>>>>>>>>> Override dataset settings here >>>>>>>>>>>>>>>>>>>
train_dataloader = dict(
dataset=dict(
type='CustomDataset',
data_root='data/custom_dataset/',
ann_file='', # We assume you are using the sub-folder format without ann_file
data_prefix='', # The `data_root` is the data_prefix directly.
with_label=False,
)
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=260,
by_epoch=True,
begin=40,
end=300,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)
By using the edited config file, you are able to train a self-supervised model with MAE algorithm on the custom dataset.
Another example: Train MAE on COCO Dataset
You need to install MMDetection to use the `mmdet.CocoDataset` follow this [documentation](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md)
Follow the aforementioned idea, we also present an example of how to train MAE on COCO dataset. The edited file will be like this:
_base_ = [
'../_base_/models/mae_vit-base-p16.py',
'../_base_/datasets/imagenet_mae.py',
'../_base_/default_runtime.py',
]
# >>>>>>>>>>>>>>> Override dataset settings here >>>>>>>>>>>>>>>>>>>
train_dataloader = dict(
dataset=dict(
type='mmdet.CocoDataset',
data_root='data/coco/',
ann_file='annotations/instances_train2017.json', # Only for loading images, and the labels won't be used.
data_prefix=dict(img='train2017/'),
)
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=260,
by_epoch=True,
begin=40,
end=300,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)