EpochBasedTraining to IterBasedTraining
Epoch-based training and iteration-based training are two commonly used training modes in MMEngine. For example, downstream repositories such as MMDetection train models by epoch, while MMSegmentation trains models by iteration.
Many modules in MMEngine default to training models by epoch, such as ParamScheduler, LoggerHook, CheckpointHook, etc. Therefore, you need to adjust the configuration of these modules if you want to train by iteration. For example, a commonly used epoch-based configuration looks like this:
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch is True by default
)

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),  # log_metric_by_epoch is True by default
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),  # by_epoch is True by default
)

train_cfg = dict(
    by_epoch=True,  # set by_epoch=True or type='EpochBasedTrainLoop'
    max_epochs=10,
    val_interval=2
)

log_processor = dict(
    by_epoch=True  # This is the default configuration, set here only for comparison.
)

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    # Assuming train_dataloader is configured with an epoch-based sampler
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)
There are four steps to convert the above configuration to iteration-based training:

- Set `by_epoch` in `train_cfg` to False, set `max_iters` to the total number of training iterations, and set `val_interval` to the validation interval in iterations.

      train_cfg = dict(
          by_epoch=False,
          max_iters=10000,
          val_interval=2000
      )

- Set `log_metric_by_epoch` to `False` in the logger hook and `by_epoch` to `False` in the checkpoint hook.

      default_hooks = dict(
          logger=dict(type='LoggerHook', log_metric_by_epoch=False),
          checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
      )

- Set `by_epoch` in param_scheduler to `False` and convert any epoch-related parameters to iterations.

      param_scheduler = dict(
          type='MultiStepLR',
          milestones=[6000, 8000],
          by_epoch=False,
      )

  Alternatively, if you can ensure that the total number of iterations is the same for iteration-based and epoch-based training, simply set `convert_to_iter_based` to True (see the conversion sketch after this list).

      param_scheduler = dict(
          type='MultiStepLR',
          milestones=[6, 8],
          convert_to_iter_based=True
      )

- Set `by_epoch` in log_processor to `False`.

      log_processor = dict(
          by_epoch=False
      )
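To make the `convert_to_iter_based` behavior in the third step concrete, here is a minimal sketch of the milestone conversion it performs. The figure of 1000 iterations per epoch is an assumption for illustration; in practice MMEngine derives it from the length of the training dataloader.

```python
# Hypothetical numbers: 10 epochs with an assumed 1000 iterations per epoch.
iters_per_epoch = 1000
epoch_milestones = [6, 8]

# Epoch milestones scaled to iteration milestones.
iter_milestones = [m * iters_per_epoch for m in epoch_milestones]
print(iter_milestones)  # [6000, 8000], matching the manually converted configuration above
```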
Take training CIFAR10 as an example:
The comparison walks through the following steps, showing the epoch-based and the iteration-based configuration side by side for each: build model, build dataloader, prepare metric, configure default hooks, configure parameter scheduler, configure log_processor, configure train_cfg, and build Runner.
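Putting the four changes together, a minimal sketch of the iteration-based counterpart of the earlier Runner configuration might look like the following. It reuses the model, dataloader, and optimizer settings from the epoch-based example; `ResNet18` and `train_dataloader_cfg` are assumed to be defined as before (for CIFAR10 they would wrap a classifier and a CIFAR10 dataloader), and the concrete iteration counts are illustrative.

```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)

train_cfg = dict(
    by_epoch=False,  # or type='IterBasedTrainLoop'
    max_iters=10000,
    val_interval=2000,
)

log_processor = dict(
    by_epoch=False,
)

runner = Runner(
    model=ResNet18(),                       # assumed to be defined as in the epoch-based example
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,  # sampler handling is discussed below
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)
```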
If the base configuration file has configured an epoch-based or iteration-based sampler for the train_dataloader, you need to either change it to the appropriate sampler type in the current configuration file or set it to None. When the sampler in the dataloader is set to None, MMEngine chooses either the InfiniteSampler (when by_epoch is False) or the DefaultSampler (when by_epoch is True) according to the train_cfg parameter.
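For example, a minimal sketch of a dataloader override that lets MMEngine pick the sampler automatically; the batch size and the dataset placeholder are assumptions, not part of the original example:

```python
train_dataloader = dict(
    batch_size=32,                # assumed value for illustration
    sampler=None,                 # MMEngine selects InfiniteSampler or DefaultSampler based on train_cfg
    dataset=cifar10_dataset_cfg,  # hypothetical dataset config defined elsewhere
)
```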
If `type` is configured for `train_cfg` in the base configuration, you must override the type with the target type (EpochBasedTrainLoop or IterBasedTrainLoop) rather than simply setting `by_epoch` to True/False.
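A minimal sketch of such an override in a derived configuration file; the base file name is hypothetical, and `_delete_=True` is used so the inherited epoch-based settings (such as max_epochs) are not merged into the new loop:

```python
_base_ = './epoch_based_config.py'  # hypothetical base config that sets type='EpochBasedTrainLoop'

train_cfg = dict(
    _delete_=True,               # drop the inherited EpochBasedTrainLoop settings
    type='IterBasedTrainLoop',   # override the loop type explicitly; flipping by_epoch alone is not enough
    max_iters=10000,
    val_interval=2000,
)
```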