[Doc]: Update hooks docs (#317)
parent
d09af9ead4
commit
e76517c63a
|
@ -15,33 +15,33 @@ PyTorch 提供了一套基础的通信原语用于多进程之间张量的通信
|
|||
|
||||
## 分布式初始化
|
||||
|
||||
- [init_dist](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.init_dist): 是分布式训练的启动函数,目前支持 pytorch,slurm,MPI 3 种分布式启动方式,同时允许设置通信的后端,默认使用 NCCL。
|
||||
- [init_dist](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.init_dist): 是分布式训练的启动函数,目前支持 pytorch,slurm,MPI 3 种分布式启动方式,同时允许设置通信的后端,默认使用 NCCL。
|
||||
|
||||
## 分布式信息获取与控制
|
||||
|
||||
分布式信息的获取与控制函数没有参数,这些函数兼容非分布式训练的情况,功能如下
|
||||
|
||||
- [get_world_size](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_world_size):获取当前进程组的进程总数,非分布式情况下返回 1
|
||||
- [get_rank](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_rank):获取当前进程对应的全局 rank 数,非分布式情况下返回 0
|
||||
- [get_backend](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_backend):获取当前通信使用的后端,非分布式情况下返回 None
|
||||
- [get_local_rank](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_local_rank):获取当前进程对应到当前机器的 rank 数,非分布式情况下返回 0
|
||||
- [get_local_size](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_local_size):获取当前进程所在机器的总进程数,非分布式情况下返回 0
|
||||
- [get_dist_info](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.get_dist_info):获取当前任务的进程总数和当前进程对应到全局的 rank 数,非分布式情况下 word_size = 1,rank = 0
|
||||
- [is_main_process](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.is_main_process):判断是否为 0 号主进程,非分布式情况下返回 True
|
||||
- [master_only](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.master_only):函数装饰器,用于修饰只需要全局 0 号进程(rank 0 而不是 local rank 0)执行的函数
|
||||
- [barrier](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.barrier):同步所有进程到达相同位置
|
||||
- [get_world_size](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_world_size):获取当前进程组的进程总数,非分布式情况下返回 1
|
||||
- [get_rank](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_rank):获取当前进程对应的全局 rank 数,非分布式情况下返回 0
|
||||
- [get_backend](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_backend):获取当前通信使用的后端,非分布式情况下返回 None
|
||||
- [get_local_rank](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_local_rank):获取当前进程对应到当前机器的 rank 数,非分布式情况下返回 0
|
||||
- [get_local_size](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_local_size):获取当前进程所在机器的总进程数,非分布式情况下返回 0
|
||||
- [get_dist_info](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.get_dist_info):获取当前任务的进程总数和当前进程对应到全局的 rank 数,非分布式情况下 word_size = 1,rank = 0
|
||||
- [is_main_process](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.is_main_process):判断是否为 0 号主进程,非分布式情况下返回 True
|
||||
- [master_only](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.master_only):函数装饰器,用于修饰只需要全局 0 号进程(rank 0 而不是 local rank 0)执行的函数
|
||||
- [barrier](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.barrier):同步所有进程到达相同位置
|
||||
|
||||
## 分布式通信函数
|
||||
|
||||
通信函数 (Collective functions),主要用于进程间数据的通信,基于 PyTorch 原生的 all_reduce,all_gather,gather,broadcast 接口,MMEngine 提供了如下接口,兼容非分布式训练的情况,并支持更丰富数据类型的通信。
|
||||
|
||||
- [all_reduce](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_reduce): 对进程间 tensor 进行 AllReduce 操作
|
||||
- [all_gather](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_gather):对进程间 tensor 进行 AllGather 操作
|
||||
- [gather](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.gather):将进程的 tensor 收集到一个目标 rank
|
||||
- [broadcast](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.broadcast):对某个进程的 tensor 进行广播
|
||||
- [sync_random_seed](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.sync_random_seed):同步进程之间的随机种子
|
||||
- [broadcast_object_list](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.broadcast_object_list):支持 object list 的广播,可以基于 broadcast 接口实现
|
||||
- [all_reduce_dict](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_reduce_dict):对 dict 中的内容进行 all_reduce 操作,基于 broadcast 和 all_reduce 接口实现
|
||||
- [all_gather_object](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_gather_object):基于 all_gather 实现对任意可以 Python 序列化对象的 all_tather 操作
|
||||
- [gather_object](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.gather_object):将 group 里每个 rank 的 data gather 到一个目标 rank,且支持多种方式
|
||||
- [collect_results](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.collect_results):支持基于 CPU 或者 GPU 对不同进程间的列表数据进行收集
|
||||
- [all_reduce](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.all_reduce): 对进程间 tensor 进行 AllReduce 操作
|
||||
- [all_gather](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.all_gather):对进程间 tensor 进行 AllGather 操作
|
||||
- [gather](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.gather):将进程的 tensor 收集到一个目标 rank
|
||||
- [broadcast](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.broadcast):对某个进程的 tensor 进行广播
|
||||
- [sync_random_seed](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.sync_random_seed):同步进程之间的随机种子
|
||||
- [broadcast_object_list](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.broadcast_object_list):支持 object list 的广播,可以基于 broadcast 接口实现
|
||||
- [all_reduce_dict](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.all_reduce_dict):对 dict 中的内容进行 all_reduce 操作,基于 broadcast 和 all_reduce 接口实现
|
||||
- [all_gather_object](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.all_gather_object):基于 all_gather 实现对任意可以 Python 序列化对象的 all_tather 操作
|
||||
- [gather_object](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.gather_object):将 group 里每个 rank 的 data gather 到一个目标 rank,且支持多种方式
|
||||
- [collect_results](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.dist.collect_results):支持基于 CPU 或者 GPU 对不同进程间的列表数据进行收集
|
||||
|
|
|
@ -41,7 +41,7 @@ import torch.nn as nn
|
|||
def forward_hook_fn(
|
||||
module, # 被注册钩子的对象
|
||||
input, # module 前向计算的输入
|
||||
output # module 前向计算的输出
|
||||
output, # module 前向计算的输出
|
||||
):
|
||||
print(f'"forward_hook_fn" is invoked by {module.name}')
|
||||
print('weight:', module.weight.data)
|
||||
|
@ -129,13 +129,14 @@ def main():
|
|||
accuracy = ...
|
||||
```
|
||||
|
||||
上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 `main` 函数。为了提高 `main` 函数的灵活性和拓展性,我们可以在 `main` 方法中插入 16 个位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。
|
||||
上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 `main` 函数。为了提高 `main` 函数的灵活性和拓展性,我们可以在 `main` 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。
|
||||
|
||||
```python
|
||||
def main():
|
||||
...
|
||||
call_hooks('before_run', hooks) # 训练开始前执行的逻辑
|
||||
call_hooks('before_run', hooks) # 任务开始前执行的逻辑
|
||||
call_hooks('after_load_checkpoint', hooks) # 加载权重后执行的逻辑
|
||||
call_hooks('before_train', hooks) # 训练开始前执行的逻辑
|
||||
for i in range(max_epochs):
|
||||
call_hooks('before_train_epoch', hooks) # 遍历训练数据集前执行的逻辑
|
||||
for inputs, labels in train_dataloader:
|
||||
|
@ -157,6 +158,7 @@ def main():
|
|||
call_hooks('after_val_epoch', hooks) # 遍历完验证数据集前执行
|
||||
|
||||
call_hooks('before_save_checkpoint', hooks) # 保存权重前执行的逻辑
|
||||
call_hooks('after_train', hooks) # 训练结束后执行的逻辑
|
||||
|
||||
call_hooks('before_test_epoch', hooks) # 遍历测试数据集前执行的逻辑
|
||||
with torch.no_grad():
|
||||
|
@ -167,12 +169,37 @@ def main():
|
|||
call_hooks('after_test_iter', hooks) # 遍历完成测试数据集后执行的逻辑
|
||||
call_hooks('after_test_epoch', hooks) # 遍历完测试数据集后执行
|
||||
|
||||
call_hooks('after_run', hooks) # 训练结束后执行的逻辑
|
||||
call_hooks('after_run', hooks) # 任务结束后执行的逻辑
|
||||
```
|
||||
|
||||
在 MMEngine 中,我们将训练过程抽象成执行器(Runner),执行器除了完成环境的初始化,另一个功能是在特定的位点调用钩子完成定制化逻辑。更多关于执行器的介绍请阅读[文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html)。
|
||||
在 MMEngine 中,我们将训练过程抽象成执行器(Runner),执行器除了完成环境的初始化,另一个功能是在特定的位点调用钩子完成定制化逻辑。更多关于执行器的介绍请阅读[执行器文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html)。
|
||||
|
||||
为了方便管理,MMEngine 将 16 个位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。
|
||||
为了方便管理,MMEngine 将位点定义为方法并集成到[钩子基类(Hook)](https://mmengine.readthedocs.io/zh/latest/api.html#hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。
|
||||
|
||||
钩子中一共有 22 个位点:
|
||||
|
||||
- before_run
|
||||
- after_run
|
||||
- before_train
|
||||
- after_train
|
||||
- before_train_epoch
|
||||
- after_train_epoch
|
||||
- before_train_iter
|
||||
- after_train_iter
|
||||
- before_val
|
||||
- after_val
|
||||
- before_test_epoch
|
||||
- after_test_epoch
|
||||
- before_val_iter
|
||||
- after_val_iter
|
||||
- before_test
|
||||
- after_test
|
||||
- before_test_epoch
|
||||
- after_test_epoch
|
||||
- before_test_iter
|
||||
- after_test_iter
|
||||
- before_save_checkpoint
|
||||
- after_load_checkpoint
|
||||
|
||||
## 内置钩子
|
||||
|
||||
|
@ -192,23 +219,23 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
|||
|
||||
**默认钩子**
|
||||
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :-----------------: | :-------------------------: | :---------------: |
|
||||
| RuntimeInfoHook | 向 message hub 更新运行时信息 | VERY_HIGH (10) |
|
||||
| OptimizerHook | 反向传播以及参数更新 | HIGH (30) |
|
||||
| DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
|
||||
| SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
|
||||
| EmptyCacheHook | PyTorch CUDA 缓存清理 | NORMAL (50) |
|
||||
| IterTimerHook | 统计迭代耗时 | NORMAL (50) |
|
||||
| LoggerHook | 打印日志 | BELOW_NORMAL (60) |
|
||||
| ParamSchedulerHook | 调用 ParamScheduler 的 step 方法 | LOW (70) |
|
||||
| CheckpointHook | 按指定间隔保存权重 | VERY_LOW (90) |
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :-----------------------------------------: | :-------------------------: | :---------------: |
|
||||
| [RuntimeInfoHook](#runtimeinfohook) | 往 message hub 更新运行时信息 | VERY_HIGH (10) |
|
||||
| [IterTimerHook](#itertimerhook) | 统计迭代耗时 | NORMAL (50) |
|
||||
| [DistSamplerSeedHook](#distsamplerseedhook) | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
|
||||
| [LoggerHook](#loggerhook) | 打印日志 | BELOW_NORMAL (60) |
|
||||
| [ParamSchedulerHook](#paramschedulerhook) | 调用 ParamScheduler 的 step 方法 | LOW (70) |
|
||||
| [CheckpointHook](#checkpointhook) | 按指定间隔保存权重 | VERY_LOW (90) |
|
||||
|
||||
**自定义钩子**
|
||||
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :------------: | :-: | :----------: |
|
||||
| VisualizerHook | 可视化 | LOWEST (100) |
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :---------------------------------: | :---------------: | :----------: |
|
||||
| [EMAHook](#emahook) | 模型参数指数滑动平均 | NORMAL (50) |
|
||||
| [EmptyCacheHook](#emptycachehook) | PyTorch CUDA 缓存清理 | NORMAL (50) |
|
||||
| [SyncBuffersHook](#syncbuffershook) | 同步模型的 buffer | NORMAL (50) |
|
||||
| NaiveVisualizationHook | 可视化 | LOWEST (100) |
|
||||
|
||||
```{note}
|
||||
不建议修改默认钩子的优先级,因为优先级低的钩子可能会依赖优先级高的钩子。例如 CheckpointHook 的优先级需要比 ParamSchedulerHook 低,这样保存的优化器状态才是正确的状态。另外,自定义钩子的优先级默认为 `NORMAL (50)`。
|
||||
|
@ -221,7 +248,6 @@ from mmengine import Runner
|
|||
|
||||
default_hooks = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
|
@ -230,7 +256,7 @@ default_hooks = dict(
|
|||
)
|
||||
|
||||
custom_hooks = [
|
||||
dict(type='VisualizerHook', priority='LOWEST'),
|
||||
dict(type='NaiveVisualizationHook', priority='LOWEST'),
|
||||
]
|
||||
|
||||
runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
|
||||
|
@ -279,62 +305,20 @@ checkpoint_config = dict(type='CheckpointHook', internal=5, max_keep_ckpts=2)
|
|||
|
||||
上述例子表示,假如一共训练 20 个 epoch,那么会在第 5, 10, 15, 20 个 epoch 保存模型,但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。
|
||||
|
||||
### OptimizerHook
|
||||
### LoggerHook
|
||||
|
||||
`OptimizerHook` 包含一些 optimizer 相关的操作:
|
||||
`LoggerHook` 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。
|
||||
|
||||
- 梯度清零 runner.optimizer.zero_grad()
|
||||
- 反向传播 runner.output\['loss'\].backward()
|
||||
- 梯度截断 clip_grads(可选)
|
||||
- 参数更新 runner.optimizer.step()
|
||||
如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:
|
||||
|
||||
```python
|
||||
from mmengine import HOOKS
|
||||
|
||||
optimizer_config = dict(type='OptimizerHook')
|
||||
HOOKS.build(optimizer_config)
|
||||
config = dict(type='LoggerHook', interval=20)
|
||||
```
|
||||
|
||||
使用以上配置即可实现在 Trainer 中完成梯度清零、反向传播以及参数更新。
|
||||
|
||||
如果我们想对梯度进行截断,避免梯度爆炸,则可以设置 grad_clip 参数,该参数的设置可参考 [clip_grad_norm\_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)
|
||||
如果我们希望训练结束后把指定后缀的文件转存到其他路径,例如 Ceph。我们可以设置 out_dir、out_suffix 和 keep_loal 三个参数。第一个参数表示将文件转存到指定的路径;第二个参数表示需要转存以哪些后缀结尾的文件,默认是 .json、.log、.py 和 yaml;第三个参数表示当我们把文件转存到其他路径后是否删除被转存的文件。
|
||||
|
||||
```python
|
||||
optimizer_config=dict(type='OptimizerHook', grad_clip=dict(max_norm=35, norm_type=2))
|
||||
```
|
||||
|
||||
模型中可能存在不参与计算图的模型参数,有两种可能,一种是该参数没有参与前向计算,另一种参与了前向计算但没有参与 loss 的计算。而如果模型中存在这种参数,会导致 PyTorch 抛出错误 `RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one`。我们可以通过设置 `detect_anomalous_params=True` 来检测并找出这种参数。
|
||||
|
||||
```python
|
||||
optimizer_config=dict(type='OptimizerHook', detect_anomalous_params=True))
|
||||
```
|
||||
|
||||
```{note}
|
||||
`detect_anomalous_params=True` 会降低训练速度,推荐只用于调试。
|
||||
```
|
||||
|
||||
除了 `OptimizerHook`,MMEngine 还提供了 `Fp16OptimizerHook` 和 `GradientCumulativeOptimizerHook`,前者用于混合精度训练,后者用于梯度累计。
|
||||
|
||||
`Fp16OptimizerHook` 是混合精度训练在 MMEngine 中的实现,主要逻辑如下:
|
||||
|
||||
- 维护一个 FP32 数值精度模型的副本
|
||||
- 在每个 iteration
|
||||
- 拷贝并且转换成 FP16 模型
|
||||
- 前向传播(FP16 的模型参数),此时 weights, activations 都是 FP16
|
||||
- loss 乘缩放参数 s,避免非 0 梯度溢出
|
||||
- 反向传播(FP16 的模型参数和参数梯度), 此时 gradients 也是 FP16
|
||||
- 参数梯度乘 1/s
|
||||
- 利用 FP16 的梯度更新 FP32 的模型参数
|
||||
|
||||

|
||||
|
||||
关于 `Fp16OptimizerHook` 的使用请阅读[如何节省显存消耗](TODO)。
|
||||
|
||||
`GradientCumulativeOptimizerHook` 用于节省显存,即通过指定梯度累积的次数,实现反向传播多次才更新参数,常常用于显存不足但希望用较大的 batch size 训练模型。
|
||||
|
||||
```python
|
||||
# cumulative_iters=4 表示累加参数梯度 4 次才更新一次参数
|
||||
optimizer_config = dict(type="GradientCumulativeOptimizerHook", cumulative_iters=4)
|
||||
config = dict(type='LoggerHook', out_dir='s3://save_log/', out_suffix=('.json', '.py'), keep_local=True)
|
||||
```
|
||||
|
||||
### ParamSchedulerHook
|
||||
|
@ -367,6 +351,14 @@ config = dict(type='IterTimerHook')
|
|||
config = dict(type='DistSamplerSeedHook')
|
||||
```
|
||||
|
||||
### EMAHook
|
||||
|
||||
`EMAHook` 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。
|
||||
|
||||
```python
|
||||
config = dict(type='EMAHook')
|
||||
```
|
||||
|
||||
### EmptyCacheHook
|
||||
|
||||
`EmptyCacheHook` 调用 `torch.cuda.empty_cache()` 释放未被使用的显存。`EmptyCacheHook` 会在 3 个位点调用 `torch.cuda.empty_cache()`,分别是 `before_epoch`, `after_iter` 以及 `after_epoch`,用户可以通过参数控制是否调用。
|
||||
|
@ -388,6 +380,10 @@ config = dict(type='SyncBuffersHook')
|
|||
`RuntimeInfoHook` 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,
|
||||
以便其他无法访问执行器的模块能够获取到这些信息。
|
||||
|
||||
```python
|
||||
config = dict(type='RuntimeInfoHook')
|
||||
```
|
||||
|
||||
## 添加自定义钩子
|
||||
|
||||
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
|
||||
|
@ -420,8 +416,8 @@ class CheckInvalidLossHook(Hook):
|
|||
"""All subclasses should override this method, if they need any
|
||||
operations after each training iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
|
|
|
@ -1687,9 +1687,9 @@ class Runner:
|
|||
+======================+=========================+
|
||||
| RuntimeInfoHook | VERY_HIGH (10) |
|
||||
+----------------------+-------------------------+
|
||||
| IterTimerHook | NORMAL (40) |
|
||||
| IterTimerHook | NORMAL (50) |
|
||||
+----------------------+-------------------------+
|
||||
| DistSamplerSeedHook | NORMAL (40) |
|
||||
| DistSamplerSeedHook | NORMAL (50) |
|
||||
+----------------------+-------------------------+
|
||||
| LoggerHook | BELOW_NORMAL (60) |
|
||||
+----------------------+-------------------------+
|
||||
|
@ -1716,8 +1716,9 @@ class Runner:
|
|||
|
||||
hooks = dict(timer=None)
|
||||
|
||||
The final registered default hooks will be :obj:`OptimizerHook`,
|
||||
:obj:`LoggerHook`, :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`.
|
||||
The final registered default hooks will be :obj:`RuntimeInfoHook`,
|
||||
:obj:`DistSamplerSeedHook`, :obj:`LoggerHook`,
|
||||
:obj:`ParamSchedulerHook` and :obj:`CheckpointHook`.
|
||||
|
||||
Args:
|
||||
hooks (dict[str, Hook or dict], optional): Default hooks or configs
|
||||
|
@ -1726,10 +1727,10 @@ class Runner:
|
|||
default_hooks: dict = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
)
|
||||
if hooks is not None:
|
||||
for name, hook in hooks.items():
|
||||
|
|
Loading…
Reference in New Issue