mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Doc] Refine runner doc. (#178)
* [Doc] Refine runner doc. * resolve comments
This commit is contained in:
parent
55713207b0
commit
ecf816e1e9
@ -91,17 +91,18 @@ runner.train()
|
||||
model = FasterRCNN()
|
||||
test_dataset = CocoDataset()
|
||||
test_dataloader = Dataloader(dataset=test_dataset, batch_size=2, num_workers=2)
|
||||
evaluator = CocoEvaluator(metric='bbox')
|
||||
metric = CocoMetric()
|
||||
test_evaluator = Evaluator(metric)
|
||||
|
||||
# 初始化执行器
|
||||
runner = Runner(model=model, test_dataloader=test_dataloader, evaluator=evaluator,
|
||||
load_checkpoint='./faster_rcnn.pth')
|
||||
runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator,
|
||||
load_from='./faster_rcnn.pth')
|
||||
|
||||
# 执行测试
|
||||
runner.test()
|
||||
```
|
||||
|
||||
这个例子中我们手动构建了一个 Faster R-CNN 检测模型,以及测试用的 COCO 数据集和对应的 COCO 评测器,并使用这些模块初始化执行器,最后通过调用执行器的 `test` 函数进行模型测试。
|
||||
这个例子中我们手动构建了一个 Faster R-CNN 检测模型,以及测试用的 COCO 数据集和使用 COCO 指标的评测器,并使用这些模块初始化执行器,最后通过调用执行器的 `test` 函数进行模型测试。
|
||||
|
||||
### 通过配置文件使用执行器
|
||||
|
||||
@ -146,12 +147,13 @@ test_dataloader = ...
|
||||
optimizer = dict(type='SGD', lr=0.01)
|
||||
# 参数调度器配置
|
||||
param_scheduler = dict(type='MultiStepLR', milestones=[80, 90])
|
||||
#评测器配置
|
||||
evaluator = dict(type='Accuracy')
|
||||
#验证和测试的评测器配置
|
||||
val_evaluator = dict(type='Accuracy')
|
||||
test_evaluator = dict(type='Accuracy')
|
||||
|
||||
# 训练、验证、测试流程配置
|
||||
train_cfg = dict(by_epoch=True, max_epochs=100)
|
||||
validation_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
|
||||
val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
|
||||
test_cfg = dict()
|
||||
|
||||
# 自定义钩子
|
||||
@ -163,20 +165,40 @@ default_hooks = dict(
|
||||
checkpoint=dict(type='CheckpointHook', interval=1), # 模型保存钩子
|
||||
logger=dict(type='TextLoggerHook'), # 训练日志钩子
|
||||
optimizer=dict(type='OptimzierHook', grad_clip=False), # 优化器钩子
|
||||
param_scheduler=dict(type='ParamSchedulerHook')) # 参数调度器执行钩子
|
||||
param_scheduler=dict(type='ParamSchedulerHook'), # 参数调度器执行钩子
|
||||
sampler_seed=dict(type='DistSamplerSeedHook')) # 为每轮次的数据采样设置随机种子的钩子
|
||||
|
||||
# 环境配置
|
||||
env_cfg = dict(
|
||||
dist_params=dict(backend='nccl'),
|
||||
cudnn_benchmark=False,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork')
|
||||
)
|
||||
# 系统日志配置
|
||||
log_cfg = dict(log_level='INFO')
|
||||
# 日志等级配置
|
||||
log_level = 'INFO'
|
||||
|
||||
# 加载权重
|
||||
load_from = None
|
||||
# 恢复训练
|
||||
resume = False
|
||||
```
|
||||
|
||||
一个完整的配置文件主要由模型、数据、优化器、参数调度器、评测器等模块的配置,训练、验证、测试等流程的配置,还有执行流程过程中的各种钩子模块的配置,以及环境和日志等其他配置的字段组成。
|
||||
通过配置文件构建的执行器采用了懒初始化 (lazy initialization),只有当调用到训练或测试等执行函数时,才会根据配置文件去完整初始化所需要的模块。
|
||||
|
||||
## 加载权重或恢复训练
|
||||
|
||||
执行器可以通过 `load_from` 参数加载检查点(checkpoint)文件中的模型权重,只需要将 `load_from` 参数设置为检查点文件的路径即可。
|
||||
|
||||
```python
|
||||
runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator,
|
||||
load_from='./faster_rcnn.pth')
|
||||
```
|
||||
|
||||
如果是通过配置文件使用执行器,只需修改配置文件中的 `load_from` 字段即可。
|
||||
|
||||
用户也可通过设置 `resume=True` 来,加载检查点中的训练状态信息来恢复训练。当 `load_from` 和 `resume=True` 同时被设置时,执行器将加载 `load_from` 路径对应的检查点文件中的训练状态。如果仅设置 `resume=True`,执行器将会尝试从 `work_dir` 文件夹中寻找并读取最新的检查点文件。
|
||||
|
||||
## 进阶使用
|
||||
|
||||
MMEngine 中的默认执行器能够完成大部分的深度学习任务,但不可避免会存在无法满足的情况。有的用户希望能够对执行器进行更多自定义修改,因此,MMEngine 支持自定义模型的训练、验证以及测试的流程。
|
||||
@ -195,48 +217,68 @@ MMEngine 内提供了四种默认的循环:
|
||||
|
||||
用户可以通过继承循环基类来实现自己的训练流程。循环基类需要提供两个输入:`runner` 执行器的实例和 `loader` 循环所需要迭代的迭代器。
|
||||
用户如果有自定义的需求,也可以增加更多的输入参数。MMEngine 中同样提供了 LOOPS 注册器对循环类进行管理,用户可以向注册器内注册自定义的循环模块,
|
||||
然后在配置文件的 `train_cfg`、`validation_cfg`、`test_cfg` 中增加 `type` 字段来指定使用何种循环。
|
||||
然后在配置文件的 `train_cfg`、`val_cfg`、`test_cfg` 中增加 `type` 字段来指定使用何种循环。
|
||||
用户可以在自定义的循环中实现任意的执行逻辑,也可以增加或删减钩子(hook)点位,但需要注意的是一旦钩子点位被修改,默认的钩子函数可能不会被执行,导致一些训练过程中默认发生的行为发生变化。
|
||||
因此,我们强烈建议用户按照本文档中定义的循环执行流程图以及[钩子规范](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/hook.html) 去重载循环基类。
|
||||
|
||||
```python
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.registry import LOOPS, HOOKS
|
||||
from mmengine.runner.loop import BaseLoop
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
|
||||
# 自定义验证循环
|
||||
@LOOPS.register_module()
|
||||
class CustomValLoop(BaseLoop):
|
||||
def __init__(self, runner, loader, evaluator, loader2):
|
||||
super().__init__(runner, loader, evaluator)
|
||||
self.loader2 = runner.build_dataloader(loader2)
|
||||
def __init__(self, runner, dataloader, evaluator, dataloader2):
|
||||
super().__init__(runner, dataloader, evaluator)
|
||||
self.dataloader2 = runner.build_dataloader(dataloader2)
|
||||
|
||||
def run(self):
|
||||
self.runner.call_hooks('before_val_epoch')
|
||||
for idx, databatch in enumerate(self.loader):
|
||||
self.runner.call_hooks('before_val_iter',
|
||||
args=dict(databatch=databatch))
|
||||
outputs = self.run_iter(idx, databatch)
|
||||
self.runner.call_hooks('after_val_iter',
|
||||
args=dict(databatch=databatch, outputs=outputs))
|
||||
for idx, data_batch in enumerate(self.dataloader):
|
||||
self.runner.call_hooks(
|
||||
'before_val_iter', batch_idx=idx, data_batch=data_batch)
|
||||
outputs = self.run_iter(idx, data_batch)
|
||||
self.runner.call_hooks(
|
||||
'after_val_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
|
||||
metric = self.evaluator.evaluate()
|
||||
for idx, databatch in enumerate(self.loader2):
|
||||
self.runner.call_hooks('before_val_iter2',
|
||||
args=dict(databatch=databatch))
|
||||
self.run_iter(idx, databatch)
|
||||
self.runner.call_hooks('after_val_iter2',
|
||||
args=dict(databatch=databatch, outputs=outputs))
|
||||
|
||||
# 增加额外的验证循环
|
||||
for idx, data_batch in enumerate(self.dataloader2):
|
||||
# 增加额外的钩子点位
|
||||
self.runner.call_hooks(
|
||||
'before_valloader2_iter', batch_idx=idx, data_batch=data_batch)
|
||||
self.run_iter(idx, data_batch)
|
||||
# 增加额外的钩子点位
|
||||
self.runner.call_hooks(
|
||||
'after_valloader2_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
|
||||
metric2 = self.evaluator.evaluate()
|
||||
|
||||
...
|
||||
|
||||
self.runner.call_hooks('after_val_epoch')
|
||||
|
||||
|
||||
# 定义额外点位的钩子类
|
||||
@HOOKS.register_module()
|
||||
class CustomValHook(Hook):
|
||||
def before_valloader2_iter(self, batch_idx, data_batch):
|
||||
...
|
||||
|
||||
def after_valloader2_iter(self, batch_idx, data_batch, outputs):
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
上面的例子中实现了一个与默认验证循环不一样的自定义验证循环,它在两个不同的验证集上进行验证,同时对第二次验证增加了额外的钩子点位,并在最后对两个验证结果进行进一步的处理。在实现了自定义的循环类之后,
|
||||
只需要在配置文件的 `validation_cfg` 内设置 `type='CustomValLoop'`,并添加额外的配置即可。
|
||||
只需要在配置文件的 `val_cfg` 内设置 `type='CustomValLoop'`,并添加额外的配置即可。
|
||||
|
||||
```python
|
||||
validation_cfg = dict(type='CustomValLoop', loader2=dict(dataset=dict(type='ValDataset2'), ...))
|
||||
# 自定义验证循环
|
||||
val_cfg = dict(type='CustomValLoop', dataloader2=dict(dataset=dict(type='ValDataset2'), ...))
|
||||
# 额外点位的钩子
|
||||
custom_hooks = [dict(type='CustomValHook')]
|
||||
```
|
||||
|
||||
### 自定义执行器
|
||||
|
Loading…
x
Reference in New Issue
Block a user