mmengine/docs/zh_cn/tutorials/hook.md

12 KiB
Raw Blame History

钩子Hook

钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。

内置钩子

MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默认钩子以及自定义钩子,前者表示会默认往执行器注册,后者表示需要用户自己注册。

每个钩子都有对应的优先级,在同一位点,钩子的优先级越高,越早被执行器调用,如果优先级一样,被调用的顺序和钩子注册的顺序一致。优先级列表如下:

  • HIGHEST (0)
  • VERY_HIGH (10)
  • HIGH (30)
  • ABOVE_NORMAL (40)
  • NORMAL (50)
  • BELOW_NORMAL (60)
  • LOW (70)
  • VERY_LOW (90)
  • LOWEST (100)

默认钩子

名称 用途 优先级
RuntimeInfoHook 往 message hub 更新运行时信息 VERY_HIGH (10)
IterTimerHook 统计迭代耗时 NORMAL (50)
DistSamplerSeedHook 确保分布式 Sampler 的 shuffle 生效 NORMAL (50)
LoggerHook 打印日志 BELOW_NORMAL (60)
ParamSchedulerHook 调用 ParamScheduler 的 step 方法 LOW (70)
CheckpointHook 按指定间隔保存权重 VERY_LOW (90)

自定义钩子

名称 用途 优先级
EMAHook 模型参数指数滑动平均 NORMAL (50)
EmptyCacheHook PyTorch CUDA 缓存清理 NORMAL (50)
SyncBuffersHook 同步模型的 buffer NORMAL (50)
NaiveVisualizationHook 可视化 LOWEST (100)
不建议修改默认钩子的优先级,因为优先级低的钩子可能会依赖优先级高的钩子。例如 CheckpointHook 的优先级需要比 ParamSchedulerHook 低,这样保存的优化器状态才是正确的状态。另外,自定义钩子的优先级默认为 `NORMAL (50)`。

两种钩子在执行器中的设置不同,默认钩子的配置传给执行器的 default_hooks 参数,自定义钩子的配置传给 custom_hooks 参数,如下所示:

from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)

custom_hooks = [
    dict(type='NaiveVisualizationHook', priority='LOWEST'),
]

runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()

下面逐一介绍 MMEngine 中内置钩子的用法。

CheckpointHook

CheckpointHook 按照给定间隔保存模型的权重如果是分布式多卡训练则只有主master进程会保存权重。CheckpointHook 的主要功能如下:

  • 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重
  • 保存最新的多个权重
  • 保存最优权重
  • 指定保存权重的路径

如需了解其他功能,请阅读CheckpointHook API 文档

下面介绍上面提到的 4 个功能。

  • 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重

假设我们一共训练 20 个 epoch 并希望每隔 5 个 epoch 保存一次权重,下面的配置即可帮我们实现该需求。

# by_epoch 的默认值为 True
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))

如果想以迭代次数作为保存间隔,则可以将 by_epoch 设为 Falseinterval=5 则表示每迭代 5 次保存一次权重。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
  • 保存最新的多个权重

如果只想保存一定数量的权重,可以通过设置 max_keep_ckpts 参数实现最多保存 max_keep_ckpts 个权重,当保存的权重数超过 max_keep_ckpts 时,前面的权重会被删除。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))

上述例子表示,假如一共训练 20 个 epoch那么会在第 5, 10, 15, 20 个 epoch 保存模型,但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。

  • 保存最优权重

如果想要保存训练过程验证集的最优权重,可以设置 save_best 参数,如果设置为 'auto',则会根据验证集的第一个评价指标(验证集返回的评价指标是一个有序字典)判断当前权重是否最优。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))

也可以直接指定 save_best 的值为评价指标,例如在分类任务中,可以指定为 save_best='top-1',则会根据 'top-1' 的值判断当前权重是否最优。

除了 save_best 参数,和保存最优权重相关的参数还有 rulegreater_keysless_keys,这三者用来判断 save_bes 的值是越大越好还是越小越好。例如指定了 save_best='top-1',可以指定 rule='greater',则表示该值越大表示权重越好。

  • 指定保存权重的路径

权重默认保存在工作目录work_dir但可以通过设置 out_dir 改变保存路径。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))

LoggerHook

LoggerHook 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。

如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

如果你对日志的管理感兴趣,可以阅读记录日志logging

ParamSchedulerHook

ParamSchedulerHook 遍历执行器的所有优化器参数调整策略Parameter Scheduler并逐个调用 step 方法更新优化器的参数。如需了解优化器参数调整策略的用法请阅读文档ParamSchedulerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

IterTimerHook

IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

DistSamplerSeedHook

DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

RuntimeInfoHook

RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等更新至 message hub 中, 以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

EMAHook

EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverageMomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

更多用法请阅读EMAHook API 文档

EmptyCacheHook

EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。 可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,第一个参数表示在每个 epoch 开始之前,第二参数表示在每次迭代之后,第三个参数表示在每个 epoch 之后。

# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

SyncBuffersHook

SyncBuffersHook 在分布式训练每一轮epoch结束时同步模型的 buffer例如 BN 层的 running_mean 以及 running_var

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

自定义钩子

如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。

例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 after_train_iter 位点。

import torch

from mmengine.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model.
                Defaults to None.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')

我们只需将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子,

from mmengine.runner import Runner

custom_hooks = dict(
    dict(type='CheckInvalidLossHook', interval=50)
)
runner = Runner(custom_hooks=custom_hooks, ...)  # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train()  # 执行器开始训练

便会在每次模型前向计算后检查损失值。

注意,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段。

custom_hooks = dict(
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
)

也可以在定义类是给定优先级

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):

    priority = 'ABOVE_NORMAL'

你可能还想阅读钩子的设计或者钩子的 API 文档