8.4 KiB
8.4 KiB
迁移 MMCV 钩子到 MMEngine
简介
由于架构设计的更新和用户需求的不断增加,MMCV 的钩子(Hook)点位已经满足不了需求,因此在 MMEngine 中对钩子点位进行了重新设计以及对钩子的功能做了调整。在开始迁移前,阅读钩子的设计会很有帮助。
本文对比 MMCV v1.6.0 和 MMEngine v0.5.0 的钩子在功能、点位、用法和实现上的差异。
功能差异
MMCV | MMEngine | |
---|---|---|
反向传播以及梯度更新 | OptimizerHook | 将反向传播以及梯度更新的操作抽象成 OptimWrapper 而不是钩子 |
GradientCumulativeOptimizerHook | ||
学习率调整 | LrUpdaterHook | ParamSchdulerHook 以及 _ParamScheduler 的子类完成优化器超参的调整 |
动量调整 | MomentumUpdaterHook | |
按指定间隔保存权重 | CheckpointHook | CheckpointHook 除了保存权重,还有保存最优权重的功能,而 EvalHook 的模型评估功能则交由 ValLoop 或 TestLoop 完成 |
模型评估并保存最优模型 | EvalHook | |
打印日志 | LoggerHook 及其子类实现打印日志、保存日志以及可视化功能 | LoggerHook |
可视化 | NaiveVisualizationHook | |
添加运行时信息 | RuntimeInfoHook | |
模型参数指数滑动平均 | EMAHook | EMAHook |
确保分布式 Sampler 的 shuffle 生效 | DistSamplerSeedHook | DistSamplerSeedHook |
同步模型的 buffer | SyncBufferHook | SyncBufferHook |
PyTorch CUDA 缓存清理 | EmptyCacheHook | EmptyCacheHook |
统计迭代耗时 | IterTimerHook | IterTimerHook |
分析训练时间的瓶颈 | ProfilerHook | 暂未提供 |
提供注册方法给钩子点位的功能 | ClosureHook | 暂未提供 |
点位差异
MMCV | MMEngine | ||
---|---|---|---|
全局位点 | 执行前 | before_run | before_run |
执行后 | after_run | after_run | |
Checkpoint 相关 | 加载 checkpoint 后 | 无 | after_load_checkpoint |
保存 checkpoint 前 | 无 | before_save_checkpoint | |
训练相关 | 训练前触发 | 无 | before_train |
训练后触发 | 无 | after_train | |
每个 epoch 前 | before_train_epoch | before_train_epoch | |
每个 epoch 后 | after_train_epoch | after_train_epoch | |
每次迭代前 | before_train_iter | before_train_iter,新增 batch_idx 和 data_batch 参数 | |
每次迭代后 | after_train_iter | after_train_iter,新增 batch_idx、data_batch 和 outputs 参数 | |
验证相关 | 验证前触发 | 无 | before_val |
验证后触发 | 无 | after_val | |
每个 epoch 前 | before_val_epoch | before_val_epoch | |
每个 epoch 后 | after_val_epoch | after_val_epoch | |
每次迭代前 | before_val_iter | before_val_iter,新增 batch_idx 和 data_batch 参数 | |
每次迭代后 | after_val_iter | after_val_iter,新增 batch_idx、data_batch 和 outputs 参数 | |
测试相关 | 测试前触发 | 无 | before_test |
测试后触发 | 无 | after_test | |
每个 epoch 前 | 无 | before_test_epoch | |
每个 epoch 后 | 无 | after_test_epoch | |
每次迭代前 | 无 | before_test_iter,新增 batch_idx 和 data_batch 参数 | |
每次迭代后 | 无 | after_test_iter,新增 batch_idx、data_batch 和 outputs 参数 |
用法差异
在 MMCV 中,将钩子注册到执行器(Runner),需调用执行器的 register_training_hooks
方法往执行器注册钩子,而在 MMEngine 中,可以通过参数传递给执行器的初始化方法进行注册。
- MMCV
model = ResNet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
lr_config = dict(policy='step', step=[2, 3])
optimizer_config = dict(grad_clip=None)
checkpoint_config = dict(interval=5)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
runner = EpochBasedRunner(
model=model,
optimizer=optimizer,
work_dir='./work_dir',
max_epochs=3,
xxx,
)
runner.register_training_hooks(
lr_config=lr_config,
optimizer_config=optimizer_config,
checkpoint_config=checkpoint_config,
log_config=log_config,
custom_hooks_config=custom_hooks,
)
runner.run([trainloader], [('train', 1)])
- MMEngine
model=ResNet18()
optim_wrapper=dict(
type='OptimizerWrapper',
optimizer=dict(type='SGD', lr=0.001, momentum=0.9))
param_scheduler = dict(type='MultiStepLR', milestones=[2, 3]),
default_hooks = dict(
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=5),
)
custom_hooks = [dict(type='NumClassCheckHook')]
runner = Runner(
model=model,
work_dir='./work_dir',
optim_wrapper=optim_wrapper,
param_scheduler=param_scheduler,
train_cfg=dict(by_epoch=True, max_epochs=3),
default_hooks=default_hooks,
custom_hooks=custom_hooks,
xxx,
)
runner.train()
MMEngine 钩子的更多用法请参考钩子的用法。
实现差异
以 CheckpointHook
为例,MMEngine 的 CheckpointHook 相比 MMCV 的 CheckpointHook(新增保存最优权重的功能,在 MMCV 中,保存最优权重的功能由 EvalHook 提供),因此,它需要实现 after_val_epoch
点位。
- MMCV
class CheckpointHook(Hook):
def before_run(self, runner):
"""初始化 out_dir 和 file_client 属性"""
def after_train_epoch(self, runner):
"""同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""
def after_train_iter(self, runner):
"""同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""
- MMEngine
class CheckpointHook(Hook):
def before_run(self, runner):
"""初始化 out_dir 和 file_client 属性"""
def after_train_epoch(self, runner):
"""同步 buffer 和保存权重,用于以 epoch 为单位训练的任务"""
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""同步 buffer 和保存权重,用于以 iteration 为单位训练的任务"""
def after_val_epoch(self, runner, metrics):
"""根据 metrics 保存最优权重"""