mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* Add epoch 2 iter * Add epoch 2 iter * Refine chinese docs * Add example for training CIFAR10 by iter * minor refine * Fix as comment * Fix as comment * Refine description * Fix as comment * minor refine * Refine description Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Translate to en * Adjust indent --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
8.5 KiB
8.5 KiB
从 EpochBased 切换至 IterBased
MMEngine 支持两种训练模式,基于轮次的 EpochBased 方式和基于迭代次数的 IterBased 方式,这两种方式在下游算法库均有使用,例如 MMDetection 默认使用 EpochBased 方式,MMSegmentation 默认使用 IterBased 方式。
MMEngine 很多模块默认以 EpochBased 的模式执行,例如 ParamScheduler, LoggerHook, CheckpointHook 等,常见的 EpochBased 配置写法如下:
param_scheduler = dict(
type='MultiStepLR',
milestones=[6, 8]
by_epoch=True # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
)
default_hooks = dict(
logger=dict(type='LoggerHook'),
checkpoint=dict(type='CheckpointHook', interval=2),
)
train_cfg = dict(
by_epoch=True, # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
max_epochs=10,
val_interval=2
)
log_processor = dict(
by_epoch=True
) # log_processor 的 by_epoch 默认为 True,这边显式的写出来只是为了方便对比, 实际上不需要设置
runner = Runner(
model=ResNet18(),
work_dir='./work_dir',
train_dataloader=train_dataloader_cfg,
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
param_scheduler=param_scheduler
default_hooks=default_hooks,
log_processor=log_processor,
train_cfg=train_cfg,
resume=True,
)
如果想按照 iter 训练模型,需要做以下改动:
-
将
train_cfg中的by_epoch设置为False,同时将max_iters设置为训练的总 iter 数,val_iterval设置为验证间隔的 iter 数。train_cfg = dict( by_epoch=False, max_iters=10000, val_interval=2000 ) -
将
default_hooks中的logger的log_metric_by_epoch设置为 False,checkpoint的by_epoch设置为False。default_hooks = dict( logger=dict(type='LoggerHook', log_metric_by_epoch=False), checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), ) -
将
param_scheduler中的by_epoch设置为False,并將epoch相关的参数换算成iterparam_scheduler = dict( type='MultiStepLR', milestones=[6000, 8000], by_epoch=False, )除了这种方式,如果你能保证 IterBasedTraining 和 EpochBasedTraining 总 iter 数一致,直接设置
convert_to_iter_based为True即可。param_scheduler = dict( type='MultiStepLR', milestones=[6, 8] convert_to_iter_based=True ) -
将
log_processor的by_epoch设置为False。log_processor = dict( by_epoch=False )
| Step | Training by epoch | Training by iteration |
|---|---|---|
| Build model | |
|
| Build dataloader |
|
|
| Prepare metric |
|
|
| Configure default hooks |
|
|
| Configure parameter scheduler |
|
|
| Configure log_processor |
|
|
| Configure train_cfg |
|
|
| Build Runner |
|
|
如果基础配置文件为 train_dataloader 配置了基于 iteration/epoch 采样的 sampler,则需要在当前配置文件中将其更改为指定类型的 sampler,或将其设置为 None。当 dataloader 中的 sampler 为 None,MMEngine 或根据 train_cfg 中的 by_epoch 参数选择 `InfiniteSampler`(False) 或 `DefaultSampler`(True)。
如果基础配置文件在 train_cfg 中指定了 type,那么必须在当前配置文件中将 type 覆盖为(IterBasedTrainLoop 或 EpochBasedTrainLoop),而不能简单的指定 by_epoch 参数。