[Docs] Add the document for the transition between IterBasedTraining and EpochBasedTraining (#926)

* Add epoch 2 iter

* Add epoch 2 iter

* Refine chinese docs

* Add example for training CIFAR10 by iter

* minor refine

* Fix as comment

* Fix as comment

* Refine description

* Fix as comment

* minor refine

* Refine description

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Translate to en

* Adjust indent

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Mashiro 2023-02-21 21:12:38 +08:00 committed by GitHub
parent 3dc2be05d5
commit 346989464c
6 changed files with 693 additions and 5 deletions


@@ -0,0 +1,343 @@
# EpochBasedTraining to IterBasedTraining
Epoch-based training and iteration-based training are two commonly used training paradigms in MMEngine. For example, downstream repositories like [MMDetection](https://github.com/open-mmlab/mmdetection) train models by epoch, while [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) trains models by iteration.
Many modules in MMEngine default to training models by epoch, such as `ParamScheduler`, `LoggerHook`, and `CheckpointHook`. Therefore, you need to adjust the configuration of these modules to train by iteration. For example, a commonly used epoch-based configuration looks like this:
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch is True by default
)
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),  # log_metric_by_epoch is True by default
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),  # by_epoch is True by default
)
train_cfg = dict(
    by_epoch=True,  # set by_epoch=True or type='EpochBasedTrainLoop'
    max_epochs=10,
    val_interval=2
)
log_processor = dict(
    by_epoch=True
)  # This is the default configuration; it is written out here only for comparison.
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    # Assuming train_dataloader is configured with an epoch-based sampler
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)
```
There are four steps to convert the above configuration to iteration-based training:
1. Set `by_epoch` in `train_cfg` to `False`, set `max_iters` to the total number of training iterations, and set `val_interval` to the validation interval in iterations.
```python
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
```
2. Set `log_metric_by_epoch` to `False` for the logger and `by_epoch` to `False` for the checkpoint hook.
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
```
3. Set `by_epoch` in `param_scheduler` to `False` and convert any epoch-related parameters to iterations (the conversion arithmetic is sketched after this list).
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
```
Alternatively, if you can ensure that the total number of iterations is the same for IterBasedTraining and EpochBasedTraining, simply set `convert_to_iter_based` to `True`.
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    convert_to_iter_based=True
)
```
4. Set `by_epoch` in `log_processor` to `False`.
```python
log_processor = dict(
    by_epoch=False
)
```
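The iteration counts above follow directly from the epoch-based settings once you know how many iterations one epoch takes. Below is a minimal sketch of that arithmetic, assuming a hypothetical dataset of 32,000 samples and a batch size of 32 so that one epoch is exactly 1,000 iterations and the numbers match the snippets above; substitute your real dataset length and batch size.
```python
# Hypothetical numbers chosen so that one epoch is 1,000 iterations.
num_samples = 32000            # assume len(train_dataset)
batch_size = 32
iters_per_epoch = num_samples // batch_size  # 1000

max_epochs = 10
milestones_epoch = [6, 8]

max_iters = max_epochs * iters_per_epoch                            # 10000
val_interval = 2 * iters_per_epoch                                  # 2000
milestones_iter = [m * iters_per_epoch for m in milestones_epoch]   # [6000, 8000]
```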
Take [training CIFAR10](../get_started/15_minutes.md) as an example:
<table class="docutils">
<thead>
<tr>
<th>Step</th>
<th>Training by epoch</th>
<th>Training by iteration</th>
</tr>
</thead>
<tbody>
<tr>
<td>Build model</td>
<td colspan="2"><div>
```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```
</div>
</td>
</tr>
<tr>
<td>Build dataloader</td>
<td colspan="2">
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
```
</td>
</tr>
<tr>
<td>Prepare metric</td>
<td colspan="2">
```python
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the intermediate results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Return the dict containing the eval results,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
```
</td>
</tr>
<tr>
<td>Configure default hooks</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
```
</div>
</td>
</tr>
<tr>
<td>Configure parameter scheduler</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
```
</div>
</td>
</tr>
<tr>
<td>Configure log_processor</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
# log_processor defaults to epoch-based training; it is defined here
# only so that the runner can be built the same way in both columns.
log_processor = dict(by_epoch=True)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
log_processor = dict(by_epoch=False)
```
</div>
</td>
</tr>
<tr>
<td>Configure train_cfg</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
```
</div>
</td>
</tr>
<tr>
<td>Build Runner</td>
<td colspan="2">
```python
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=train_cfg,
    log_processor=log_processor,
    default_hooks=default_hooks,
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
```
</td>
</tr>
</tbody>
</table>
```{note}
If the base configuration file configures an epoch- or iteration-based sampler for the train_dataloader, you need to change it to the required type of sampler in the current configuration file, or set it to None. When the sampler in the dataloader is None, MMEngine chooses `InfiniteSampler` (when `by_epoch` is False) or `DefaultSampler` (when `by_epoch` is True) according to `by_epoch` in `train_cfg`. A sketch of the explicit-sampler option follows this note.
```
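A minimal sketch of the first option, assuming the dataloader is passed to the Runner as a config dict; `train_dataset_cfg` is a hypothetical dataset config defined elsewhere:
```python
# Name the sampler type explicitly so it matches the training loop:
# InfiniteSampler for iteration-based training, DefaultSampler for epoch-based.
train_dataloader = dict(
    batch_size=32,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=train_dataset_cfg,  # hypothetical dataset config
)
```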
```{note}
If `type` is configured for `train_cfg` in the base configuration file, you must overwrite it with the target type (`EpochBasedTrainLoop` or `IterBasedTrainLoop`) rather than simply setting `by_epoch` to True/False, as shown in the sketch after this note.
```
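A minimal sketch of such an override, assuming a hypothetical base config that fixes the loop type:
```python
# Hypothetical base config (base.py):
#   train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=2)
# Current config: override the loop type explicitly; flipping by_epoch alone
# would conflict with the type already set in the base file.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=10000,
    val_interval=2000
)
```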


@@ -3,10 +3,11 @@
In this tutorial, we'll take training a ResNet-50 model on the CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEngine`.
The whole process includes the following steps:
1. [Build a Model](#build-a-model)
2. [Build a Dataset and DataLoader](#build-a-dataset-and-dataloader)
3. [Build a Evaluation Metrics](#build-a-evaluation-metrics)
4. [Build a Runner and Run the Task](#build-a-runner-and-run-the-task)
- [15 minutes to get started with MMEngine](#15-minutes-to-get-started-with-mmengine)
- [Build a Model](#build-a-model)
- [Build a Dataset and DataLoader](#build-a-dataset-and-dataloader)
- [Build a Evaluation Metrics](#build-a-evaluation-metrics)
- [Build a Runner and Run the Task](#build-a-runner-and-run-the-task)
## Build a Model


@@ -23,6 +23,7 @@ You can switch between Chinese and English documents in the lower-left corner of
common_usage/resume_training.md
common_usage/speed_up_training.md
common_usage/save_gpu_memory.md
common_usage/epoch_to_iter.md
.. toctree::
:maxdepth: 3


@@ -0,0 +1,342 @@
# Switching from EpochBased to IterBased
MMEngine supports two training modes: epoch-based (EpochBased) and iteration-based (IterBased). Both are used in downstream repositories; for example, [MMDetection](https://github.com/open-mmlab/mmdetection) uses the EpochBased mode by default, while [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) uses the IterBased mode by default.
Many modules in MMEngine run in EpochBased mode by default, such as `ParamScheduler`, `LoggerHook`, and `CheckpointHook`. A common EpochBased configuration is written as follows:
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True  # by_epoch defaults to True; written out explicitly here only for comparison
)
default_hooks = dict(
    logger=dict(type='LoggerHook'),
    checkpoint=dict(type='CheckpointHook', interval=2),
)
train_cfg = dict(
    by_epoch=True,  # by_epoch defaults to True; written out explicitly here only for comparison
    max_epochs=10,
    val_interval=2
)
log_processor = dict(
    by_epoch=True
)  # by_epoch of log_processor defaults to True; written out here only for comparison and does not actually need to be set
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler,
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)
```
To train the model by iteration, make the following changes:
1. Set `by_epoch` in `train_cfg` to `False`, set `max_iters` to the total number of training iterations, and set `val_interval` to the validation interval in iterations.
```python
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
```
2. Set `log_metric_by_epoch` of the `logger` in `default_hooks` to `False`, and set `by_epoch` of the `checkpoint` to `False`.
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
```
3. Set `by_epoch` in `param_scheduler` to `False` and convert any epoch-related parameters to iterations.
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
```
Alternatively, if you can guarantee that the total number of iterations is the same for IterBasedTraining and EpochBasedTraining, simply set `convert_to_iter_based` to `True`.
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    convert_to_iter_based=True
)
```
4. Set `by_epoch` in `log_processor` to `False` (all four changes are gathered into a consolidated sketch after this list).
```python
log_processor = dict(
    by_epoch=False
)
```
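Putting the four changes together, a consolidated sketch of the iteration-based configuration (the values mirror the snippets above):
```python
train_cfg = dict(by_epoch=False, max_iters=10000, val_interval=2000)
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
log_processor = dict(by_epoch=False)
```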
Take [training CIFAR10 in the 15-minute tutorial](../get_started/15_minutes.md) as an example:
<table class="docutils">
<thead>
<tr>
<th>Step</th>
<th>Training by epoch</th>
<th>Training by iteration</th>
</tr>
</thead>
<tbody>
<tr>
<td>Build model</td>
<td colspan="2"><div>
```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```
</div>
</td>
</tr>
<tr>
<td>Build dataloader</td>
<td colspan="2">
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
val_dataloader = DataLoader(
    batch_size=32,
    shuffle=False,
    dataset=torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)])))
```
</td>
</tr>
<tr>
<td>Prepare metric</td>
<td colspan="2">
```python
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the intermediate results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Return the dict containing the eval results,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
```
</td>
</tr>
<tr>
<td>Configure default hooks</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
```
</div>
</td>
</tr>
<tr>
<td>Configure parameter scheduler</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)
```
</div>
</td>
</tr>
<tr>
<td>Configure log_processor</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
# log_processor defaults to epoch-based training; it is defined here
# only so that the runner can be built the same way in both columns.
log_processor = dict(by_epoch=True)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
log_processor = dict(by_epoch=False)
```
</div>
</td>
</tr>
<tr>
<td>Configure train_cfg</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)
```
</div>
</td>
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
<div style="overflow-x: auto">
```python
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)
```
</div>
</td>
</tr>
<tr>
<td>Build Runner</td>
<td colspan="2">
```python
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=train_cfg,
    log_processor=log_processor,
    default_hooks=default_hooks,
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
```
</td>
</tr>
</tbody>
</table>
```{note}
If the base configuration file configures an iteration- or epoch-based sampler for the train_dataloader, you need to change it to the required type of sampler in the current configuration file, or set it to None. When the sampler in the dataloader is None, MMEngine chooses `InfiniteSampler` (when `by_epoch` is False) or `DefaultSampler` (when `by_epoch` is True) according to `by_epoch` in `train_cfg`. A sketch of the sampler=None option follows this note.
```
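A minimal sketch of the second option, assuming the dataloader is passed to the Runner as a config dict; `train_dataset_cfg` is a hypothetical dataset config defined elsewhere:
```python
# With sampler=None, MMEngine picks the sampler from train_cfg:
# InfiniteSampler when by_epoch=False, DefaultSampler when by_epoch=True.
train_dataloader = dict(
    batch_size=32,
    sampler=None,
    dataset=train_dataset_cfg,  # hypothetical dataset config
)
```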
```{note}
If `type` is specified in `train_cfg` of the base configuration file, you must override `type` with the target loop type (`IterBasedTrainLoop` or `EpochBasedTrainLoop`) in the current configuration file, rather than simply setting the `by_epoch` parameter.
```


@@ -25,6 +25,7 @@
common_usage/save_gpu_memory.md
common_usage/set_random_seed.md
common_usage/set_interval.md
common_usage/epoch_to_iter.md
.. toctree::
:maxdepth: 3


@@ -170,7 +170,7 @@ class Runner:
as possible like seed and deterministic.
Defaults to ``dict(seed=None)``. If seed is None, a random number
will be generated and it will be broadcasted to all other processes
if in distributed environment. If ``cudnn_benchmarch`` is
if in distributed environment. If ``cudnn_benchmark`` is
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will be ``False`` finally.