[Docs] Add the document for the transition between IterBasedTraining and EpochBasedTraining (#926)
* Add epoch 2 iter * Add epoch 2 iter * Refine chinese docs * Add example for training CIFAR10 by iter * minor refine * Fix as comment * Fix as comment * Refine description * Fix as comment * minor refine * Refine description Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Translate to en * Adjust indent --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/938/head
parent
3dc2be05d5
commit
346989464c
|
@ -0,0 +1,343 @@
|
|||
# EpochBasedTraining to IterBasedTraining
|
||||
|
||||
Epoch-based training and iteration-based training are two commonly used training way in MMEngine. For example, downstream repositories like [MMDetection](https://github.com/open-mmlab/mmdetection) choose to train the model by epoch and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) choose to train the model by iteration.
|
||||
|
||||
Many modules in MMEngine default to training models by epoch, such as `ParamScheduler`, `LoggerHook`, `CheckPointHook`, etc. Therefore, you need to adjust the configuration of these modules if you want to train by iteration. For example, a commonly used epoch based configuration is as follows:
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8]
|
||||
by_epoch=True # by_epoch is True by default
|
||||
)
|
||||
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=True), # log_metric_by_epoch is True by default
|
||||
checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True), # by_epoch is True by default
|
||||
)
|
||||
|
||||
train_cfg = dict(
|
||||
by_epoch=True, # set by_epoch=True or type='EpochBasedTrainLoop'
|
||||
max_epochs=10,
|
||||
val_interval=2
|
||||
)
|
||||
|
||||
log_processor = dict(
|
||||
by_epoch=True
|
||||
) # This is the default configuration, and just set it here for comparison.
|
||||
|
||||
runner = Runner(
|
||||
model=ResNet18(),
|
||||
work_dir='./work_dir',
|
||||
# Assuming train_dataloader is configured with an epoch-based sampler
|
||||
train_dataloader=train_dataloader_cfg,
|
||||
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
|
||||
param_scheduler=param_scheduler
|
||||
default_hooks=default_hooks,
|
||||
log_processor=log_processor,
|
||||
train_cfg=train_cfg,
|
||||
resume=True,
|
||||
)
|
||||
```
|
||||
|
||||
There are four steps to convert the above configuration to iteration based training:
|
||||
|
||||
1. Set `by_epoch` in `train_cfg` to False, and set `max_iters` to the total number of training iterations and `val_interval` to the interval between validation iterations.
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=False,
|
||||
max_iters=10000,
|
||||
val_interval=2000
|
||||
)
|
||||
```
|
||||
|
||||
2. Set `log_metric_by_epoch` to `False` in logger and `by_epoch` to `False` in checkpoint.
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
|
||||
)
|
||||
```
|
||||
|
||||
3. Set `by_epoch` in param_scheduler to `False` and convert any epoch-related parameters to iteration.
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6000, 8000],
|
||||
by_epoch=False,
|
||||
)
|
||||
```
|
||||
|
||||
Alternatively, if you can ensure that the total number of iterations for IterBasedTraining and EpochBasedTraining is the same, simply set `convert_to_iter_based` to True.
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8]
|
||||
convert_to_iter_based=True
|
||||
)
|
||||
```
|
||||
|
||||
4. Set by_epoch in log_processor to False.
|
||||
|
||||
```python
|
||||
log_processor = dict(
|
||||
by_epoch=False
|
||||
)
|
||||
```
|
||||
|
||||
Take [training CIFAR10](../get_started/15_minutes.md) as an example:
|
||||
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Step</th>
|
||||
<th>Training by epoch</th>
|
||||
<th>Training by iteration</th>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>Build model</td>
|
||||
<td colspan="2"><div>
|
||||
|
||||
```python
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
|
||||
class MMResNet50(BaseModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = torchvision.models.resnet50()
|
||||
|
||||
def forward(self, imgs, labels, mode):
|
||||
x = self.resnet(imgs)
|
||||
if mode == 'loss':
|
||||
return {'loss': F.cross_entropy(x, labels)}
|
||||
elif mode == 'predict':
|
||||
return x, labels
|
||||
```
|
||||
|
||||
</td>
|
||||
</div>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Build dataloader</td>
|
||||
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
|
||||
train_dataloader = DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
dataset=torchvision.datasets.CIFAR10(
|
||||
'data/cifar10',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])))
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=False,
|
||||
dataset=torchvision.datasets.CIFAR10(
|
||||
'data/cifar10',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])))
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Prepare metric</td>
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
class Accuracy(BaseMetric):
|
||||
def process(self, data_batch, data_samples):
|
||||
score, gt = data_samples
|
||||
# save the middle result of a batch to `self.results`
|
||||
self.results.append({
|
||||
'batch_size': len(gt),
|
||||
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
|
||||
})
|
||||
|
||||
def compute_metrics(self, results):
|
||||
total_correct = sum(item['correct'] for item in results)
|
||||
total_size = sum(item['batch_size'] for item in results)
|
||||
# return the dict containing the eval results
|
||||
# the key is the name of the metric name
|
||||
return dict(accuracy=100 * total_correct / total_size)
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure default hooks</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=True),
|
||||
checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure parameter scheduler</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8]
|
||||
by_epoch=True
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6000, 8000],
|
||||
by_epoch=False,
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure log_processor</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
# The default configuration of log_processor is used for epoch based training.
|
||||
# Defining it here additionally is for building runner with the same way.
|
||||
log_processor = dict(by_epoch=True)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
log_processor = dict(by_epoch=False)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure train_cfg</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=True,
|
||||
max_epochs=10,
|
||||
val_interval=2
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=False,
|
||||
max_iters=10000,
|
||||
val_interval=2000
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Build Runner</td>
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
from torch.optim import SGD
|
||||
from mmengine.runner import Runner
|
||||
runner = Runner(
|
||||
model=MMResNet50(),
|
||||
work_dir='./work_dir',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
|
||||
train_cfg=train_cfg,
|
||||
log_processor=log_processor,
|
||||
default_hooks=default_hooks,
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=Accuracy),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</thead>
|
||||
</table>
|
||||
|
||||
```{note}
|
||||
If the base configuration file has configured a epoch/iteration based sampler for the train_dataloader, then it is necessary to change it to a specified type of sampler in the current configuration file, or set it to None. When the sampler in the dataloader is set to None, MMEngine will choose either the InfiniteSampler (when by_epoch is False) or the DefaultSampler (when by_epoch is True) according to the train_cfg parameter.
|
||||
```
|
||||
|
||||
```{note}
|
||||
If `type` is configured for the `train_cfg` in the base configuration, you must overwrite the type to target type (EpochBasedTrainLoop or IterBasedTrainLoop) rather than simply set `by_epoch` to True/False.
|
||||
```
|
|
@ -3,10 +3,11 @@
|
|||
In this tutorial, we'll take training a ResNet-50 model on CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEgnine`.
|
||||
The whole process includes the following steps:
|
||||
|
||||
1. [Build a Model](#build-a-model)
|
||||
2. [Build a Dataset and DataLoader](#build-a-dataset-and-dataloader)
|
||||
3. [Build a Evaluation Metrics](#build-a-evaluation-metrics)
|
||||
4. [Build a Runner and Run the Task](#build-a-runner-and-run-the-task)
|
||||
- [15 minutes to get started with MMEngine](#15-minutes-to-get-started-with-mmengine)
|
||||
- [Build a Model](#build-a-model)
|
||||
- [Build a Dataset and DataLoader](#build-a-dataset-and-dataloader)
|
||||
- [Build a Evaluation Metrics](#build-a-evaluation-metrics)
|
||||
- [Build a Runner and Run the Task](#build-a-runner-and-run-the-task)
|
||||
|
||||
## Build a Model
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ You can switch between Chinese and English documents in the lower-left corner of
|
|||
common_usage/resume_training.md
|
||||
common_usage/speed_up_training.md
|
||||
common_usage/save_gpu_memory.md
|
||||
common_usage/epoch_to_iter.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
# 从 EpochBased 切换至 IterBased
|
||||
|
||||
MMEngine 支持两种训练模式,基于轮次的 EpochBased 方式和基于迭代次数的 IterBased 方式,这两种方式在下游算法库均有使用,例如 [MMDetection](https://github.com/open-mmlab/mmdetection) 默认使用 EpochBased 方式,[MMSegmentation](https://github.com/open-mmlab/mmsegmentation) 默认使用 IterBased 方式。
|
||||
|
||||
MMEngine 很多模块默认以 EpochBased 的模式执行,例如 `ParamScheduler`, `LoggerHook`, `CheckpointHook` 等,常见的 EpochBased 配置写法如下:
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8]
|
||||
by_epoch=True # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
|
||||
)
|
||||
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=2),
|
||||
)
|
||||
|
||||
train_cfg = dict(
|
||||
by_epoch=True, # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
|
||||
max_epochs=10,
|
||||
val_interval=2
|
||||
)
|
||||
|
||||
log_processor = dict(
|
||||
by_epoch=True
|
||||
) # log_processor 的 by_epoch 默认为 True,这边显式的写出来只是为了方便对比, 实际上不需要设置
|
||||
|
||||
runner = Runner(
|
||||
model=ResNet18(),
|
||||
work_dir='./work_dir',
|
||||
train_dataloader=train_dataloader_cfg,
|
||||
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
|
||||
param_scheduler=param_scheduler
|
||||
default_hooks=default_hooks,
|
||||
log_processor=log_processor,
|
||||
train_cfg=train_cfg,
|
||||
resume=True,
|
||||
)
|
||||
```
|
||||
|
||||
如果想按照 iter 训练模型,需要做以下改动:
|
||||
|
||||
1. 将 `train_cfg` 中的 `by_epoch` 设置为 `False`,同时将 `max_iters` 设置为训练的总 iter 数,`val_iterval` 设置为验证间隔的 iter 数。
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=False,
|
||||
max_iters=10000,
|
||||
val_interval=2000
|
||||
)
|
||||
```
|
||||
|
||||
2. 将 `default_hooks` 中的 `logger` 的 `log_metric_by_epoch` 设置为 False, `checkpoint` 的 `by_epoch` 设置为 `False`。
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
|
||||
)
|
||||
```
|
||||
|
||||
3. 将 `param_scheduler` 中的 `by_epoch` 设置为 `False`,并將 `epoch` 相关的参数换算成 `iter`
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6000, 8000],
|
||||
by_epoch=False,
|
||||
)
|
||||
```
|
||||
|
||||
除了这种方式,如果你能保证 IterBasedTraining 和 EpochBasedTraining 总 iter 数一致,直接设置 `convert_to_iter_based` 为 `True` 即可。
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8]
|
||||
convert_to_iter_based=True
|
||||
)
|
||||
```
|
||||
|
||||
4. 将 `log_processor` 的 `by_epoch` 设置为 `False`。
|
||||
|
||||
```python
|
||||
log_processor = dict(
|
||||
by_epoch=False
|
||||
)
|
||||
```
|
||||
|
||||
以 [15 分钟教程训练 CIFAR10 数据集](../get_started/15_minutes.md)为例:
|
||||
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Step</th>
|
||||
<th>Training by epoch</th>
|
||||
<th>Training by iteration</th>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>Build model</td>
|
||||
<td colspan="2"><div>
|
||||
|
||||
```python
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
|
||||
class MMResNet50(BaseModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = torchvision.models.resnet50()
|
||||
|
||||
def forward(self, imgs, labels, mode):
|
||||
x = self.resnet(imgs)
|
||||
if mode == 'loss':
|
||||
return {'loss': F.cross_entropy(x, labels)}
|
||||
elif mode == 'predict':
|
||||
return x, labels
|
||||
```
|
||||
|
||||
</td>
|
||||
</div>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Build dataloader</td>
|
||||
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
|
||||
train_dataloader = DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
dataset=torchvision.datasets.CIFAR10(
|
||||
'data/cifar10',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])))
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=False,
|
||||
dataset=torchvision.datasets.CIFAR10(
|
||||
'data/cifar10',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])))
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Prepare metric</td>
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
class Accuracy(BaseMetric):
|
||||
def process(self, data_batch, data_samples):
|
||||
score, gt = data_samples
|
||||
# save the middle result of a batch to `self.results`
|
||||
self.results.append({
|
||||
'batch_size': len(gt),
|
||||
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
|
||||
})
|
||||
|
||||
def compute_metrics(self, results):
|
||||
total_correct = sum(item['correct'] for item in results)
|
||||
total_size = sum(item['batch_size'] for item in results)
|
||||
# return the dict containing the eval results
|
||||
# the key is the name of the metric name
|
||||
return dict(accuracy=100 * total_correct / total_size)
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure default hooks</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=True),
|
||||
checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure parameter scheduler</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6, 8],
|
||||
by_epoch=True,
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
param_scheduler = dict(
|
||||
type='MultiStepLR',
|
||||
milestones=[6000, 8000],
|
||||
by_epoch=False,
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure log_processor</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
# The default configuration of log_processor is used for epoch based training.
|
||||
# Defining it here additionally is for building runner with the same way.
|
||||
log_processor = dict(by_epoch=True)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
log_processor = dict(by_epoch=False)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Configure train_cfg</td>
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=True,
|
||||
max_epochs=10,
|
||||
val_interval=2
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td valign="top" class='two-column-table-wrapper' width="50%" colspan="1">
|
||||
<div style="overflow-x: auto">
|
||||
|
||||
```python
|
||||
train_cfg = dict(
|
||||
by_epoch=False,
|
||||
max_iters=10000,
|
||||
val_interval=2000
|
||||
)
|
||||
```
|
||||
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Build Runner</td>
|
||||
<td colspan="2">
|
||||
|
||||
```python
|
||||
from torch.optim import SGD
|
||||
from mmengine.runner import Runner
|
||||
runner = Runner(
|
||||
model=MMResNet50(),
|
||||
work_dir='./work_dir',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
|
||||
train_cfg=train_cfg,
|
||||
log_processor=log_processor,
|
||||
default_hooks=default_hooks,
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=Accuracy),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</thead>
|
||||
</table>
|
||||
|
||||
```{note}
|
||||
如果基础配置文件为 train_dataloader 配置了基于 iteration/epoch 采样的 sampler,则需要在当前配置文件中将其更改为指定类型的 sampler,或将其设置为 None。当 dataloader 中的 sampler 为 None,MMEngine 或根据 train_cfg 中的 by_epoch 参数选择 `InfiniteSampler`(False) 或 `DefaultSampler`(True)。
|
||||
```
|
||||
|
||||
```{note}
|
||||
如果基础配置文件在 train_cfg 中指定了 type,那么必须在当前配置文件中将 type 覆盖为(IterBasedTrainLoop 或 EpochBasedTrainLoop),而不能简单的指定 by_epoch 参数。
|
||||
```
|
|
@ -25,6 +25,7 @@
|
|||
common_usage/save_gpu_memory.md
|
||||
common_usage/set_random_seed.md
|
||||
common_usage/set_interval.md
|
||||
common_usage/epoch_to_iter.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
|
|
@ -170,7 +170,7 @@ class Runner:
|
|||
as possible like seed and deterministic.
|
||||
Defaults to ``dict(seed=None)``. If seed is None, a random number
|
||||
will be generated and it will be broadcasted to all other processes
|
||||
if in distributed environment. If ``cudnn_benchmarch`` is
|
||||
if in distributed environment. If ``cudnn_benchmark`` is
|
||||
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
|
||||
``randomness``, the value of ``torch.backends.cudnn.benchmark``
|
||||
will be ``False`` finally.
|
||||
|
|
Loading…
Reference in New Issue