diff --git a/docs/en/migration/hook.md b/docs/en/migration/hook.md
index daabe338..0d4ac06d 100644
--- a/docs/en/migration/hook.md
+++ b/docs/en/migration/hook.md
@@ -1,3 +1,319 @@
 # Migrate Hook from MMCV to MMEngine
 
-Coming soon. Please refer to [chinese documentation](https://mmengine.readthedocs.io/zh_CN/latest/migration/hook.html).
+## Introduction
+
+As our architecture evolved and user demands grew, the hook mount points in MMCV could no longer meet our requirements. We therefore redesigned the mount points in MMEngine and adjusted the functions of the hooks accordingly. We recommend reading the [Hook Design](../design/hook.md) tutorial before you migrate.
+
+This tutorial compares the functions, mount points, usage, and implementation of hooks in [MMCV v1.6.0](https://github.com/open-mmlab/mmcv/tree/v1.6.0) and [MMEngine v0.5.0](https://github.com/open-mmlab/mmengine/tree/v0.5.0).
+
+## Function Comparison
+
+| Function | MMCV | MMEngine |
+| --- | --- | --- |
+| Backpropagation and gradient update | `OptimizerHook`, `GradientCumulativeOptimizerHook` | Backpropagation and gradient update are unified in `OptimWrapper` rather than implemented as hooks |
+| Learning rate adjustment | `LrUpdaterHook` | `ParamSchedulerHook` and subclasses of `_ParamScheduler` adjust the optimizer hyperparameters |
+| Momentum adjustment | `MomentumUpdaterHook` | `ParamSchedulerHook` and subclasses of `_ParamScheduler` adjust the optimizer hyperparameters |
+| Saving model weights at a specified interval | `CheckpointHook` | `CheckpointHook` saves weights at intervals and also saves the optimal weights; the model evaluation function of `EvalHook` is delegated to `ValLoop` or `TestLoop` |
+| Model evaluation and optimal weights saving | `EvalHook` | `CheckpointHook` saves weights at intervals and also saves the optimal weights; the model evaluation function of `EvalHook` is delegated to `ValLoop` or `TestLoop` |
+| Log printing | `LoggerHook` and its subclasses print logs, save logs, and visualize data | `LoggerHook` |
+| Visualization | `LoggerHook` and its subclasses print logs, save logs, and visualize data | `NaiveVisualizationHook` |
+| Adding runtime information | `LoggerHook` and its subclasses print logs, save logs, and visualize data | `RuntimeInfoHook` |
+| Model weights exponential moving average (EMA) | `EMAHook` | `EMAHook` |
+| Ensuring that the shuffle functionality of the distributed Sampler takes effect | `DistSamplerSeedHook` | `DistSamplerSeedHook` |
+| Synchronizing model buffers | `SyncBufferHook` | `SyncBufferHook` |
+| Emptying the PyTorch CUDA cache | `EmptyCacheHook` | `EmptyCacheHook` |
+| Measuring the time spent on each iteration | `IterTimerHook` | `IterTimerHook` |
+| Analyzing bottlenecks of training time | `ProfilerHook` | Not yet available |
+| Providing the most concise function registration | `ClosureHook` | Not yet available |
+
+## Mount Point Comparison
+
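+| Category | Mount point | MMCV | MMEngine |
+| --- | --- | --- | --- |
+| Global mount points | before run | `before_run` | `before_run` |
+| Global mount points | after run | `after_run` | `after_run` |
+| Checkpoint related | after loading checkpoints | None | `after_load_checkpoint` |
+| Checkpoint related | before saving checkpoints | None | `before_save_checkpoint` |
+| Training related | before training | None | `before_train` |
+| Training related | after training | None | `after_train` |
+| Training related | before each epoch | `before_train_epoch` | `before_train_epoch` |
+| Training related | after each epoch | `after_train_epoch` | `after_train_epoch` |
+| Training related | before each iteration | `before_train_iter` | `before_train_iter`, with additional args: `batch_idx` and `data_batch` |
+| Training related | after each iteration | `after_train_iter` | `after_train_iter`, with additional args: `batch_idx`, `data_batch`, and `outputs` |
+| Validation related | before validation | None | `before_val` |
+| Validation related | after validation | None | `after_val` |
+| Validation related | before each epoch | `before_val_epoch` | `before_val_epoch` |
+| Validation related | after each epoch | `after_val_epoch` | `after_val_epoch` |
+| Validation related | before each iteration | `before_val_iter` | `before_val_iter`, with additional args: `batch_idx` and `data_batch` |
+| Validation related | after each iteration | `after_val_iter` | `after_val_iter`, with additional args: `batch_idx`, `data_batch`, and `outputs` |
+| Test related | before test | None | `before_test` |
+| Test related | after test | None | `after_test` |
+| Test related | before each epoch | None | `before_test_epoch` |
+| Test related | after each epoch | None | `after_test_epoch` |
+| Test related | before each iteration | None | `before_test_iter`, with additional args: `batch_idx` and `data_batch` |
+| Test related | after each iteration | None | `after_test_iter`, with additional args: `batch_idx`, `data_batch`, and `outputs` |
+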
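To make the signature change concrete, the following is a minimal, self-contained sketch of a hook that relies on the extra `batch_idx`, `data_batch`, and `outputs` arguments. The `Hook`, `ToyRunner`, and `LossRecorderHook` classes are stand-ins written only for this illustration and are not MMEngine classes; only the mount point names and argument names come from the table above, and the "training step" is faked with an average.

```python
class Hook:
    """Stand-in base class (not mmengine.hooks.Hook): every mount point is a no-op."""

    def before_train(self, runner):
        pass

    def before_train_epoch(self, runner):
        pass

    def before_train_iter(self, runner, batch_idx, data_batch=None):
        pass

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        pass

    def after_train_epoch(self, runner):
        pass

    def after_train(self, runner):
        pass


class LossRecorderHook(Hook):
    """Hypothetical hook that records the loss handed to after_train_iter."""

    def __init__(self):
        self.calls = []
        self.losses = []

    def before_train(self, runner):
        self.calls.append('before_train')

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        # batch_idx, data_batch and outputs arrive as arguments of the mount point.
        self.losses.append((runner.epoch, batch_idx, outputs['loss']))

    def after_train(self, runner):
        self.calls.append('after_train')


class ToyRunner:
    """Stand-in runner that only dispatches the training mount points in order."""

    def __init__(self, batches, hooks, max_epochs):
        self.batches = batches
        self.hooks = hooks
        self.max_epochs = max_epochs
        self.epoch = 0

    def call_hook(self, fn_name, **kwargs):
        for hook in self.hooks:
            getattr(hook, fn_name)(self, **kwargs)

    def train(self):
        self.call_hook('before_train')
        for epoch in range(self.max_epochs):
            self.epoch = epoch
            self.call_hook('before_train_epoch')
            for batch_idx, data_batch in enumerate(self.batches):
                self.call_hook('before_train_iter',
                               batch_idx=batch_idx, data_batch=data_batch)
                # Fake training step: the "loss" is just the mean of the batch.
                outputs = {'loss': sum(data_batch) / len(data_batch)}
                self.call_hook('after_train_iter', batch_idx=batch_idx,
                               data_batch=data_batch, outputs=outputs)
            self.call_hook('after_train_epoch')
        self.call_hook('after_train')


recorder = LossRecorderHook()
ToyRunner(batches=[[1.0, 3.0], [2.0, 4.0]], hooks=[recorder], max_epochs=2).train()
print(recorder.losses)
```

A hook written against MMCV's `after_train_iter(self, runner)` signature would have to read the same values from runner attributes instead, which is the main thing to change when porting iteration-level hooks.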
+
+## Usage Comparison
+
+In MMCV, you register hooks by calling the Runner's `register_training_hooks` method. In MMEngine, you register hooks by passing them as arguments to the Runner's initialization method.
+
+- MMCV
+
+```python
+import torch.optim as optim
+from mmcv.runner import EpochBasedRunner
+
+# ResNet18 and trainloader are assumed to be defined by the user
+model = ResNet18()
+optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+lr_config = dict(policy='step', step=[2, 3])
+optimizer_config = dict(grad_clip=None)
+checkpoint_config = dict(interval=5)
+log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
+custom_hooks = [dict(type='NumClassCheckHook')]
+runner = EpochBasedRunner(
+    model=model,
+    optimizer=optimizer,
+    work_dir='./work_dir',
+    max_epochs=3,
+    # xxx: other arguments omitted
+)
+runner.register_training_hooks(
+    lr_config=lr_config,
+    optimizer_config=optimizer_config,
+    checkpoint_config=checkpoint_config,
+    log_config=log_config,
+    custom_hooks_config=custom_hooks,
+)
+runner.run([trainloader], [('train', 1)])
+```
+
+- MMEngine
+
+```python
+from mmengine.runner import Runner
+
+# ResNet18 is assumed to be defined by the user
+model = ResNet18()
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='SGD', lr=0.001, momentum=0.9))
+param_scheduler = dict(type='MultiStepLR', milestones=[2, 3])
+default_hooks = dict(
+    logger=dict(type='LoggerHook'),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    checkpoint=dict(type='CheckpointHook', interval=5),
+)
+custom_hooks = [dict(type='NumClassCheckHook')]
+runner = Runner(
+    model=model,
+    work_dir='./work_dir',
+    optim_wrapper=optim_wrapper,
+    param_scheduler=param_scheduler,
+    train_cfg=dict(by_epoch=True, max_epochs=3),
+    default_hooks=default_hooks,
+    custom_hooks=custom_hooks,
+    # xxx: other arguments omitted
+)
+runner.train()
+```
+
+For more details of MMEngine hooks, please refer to [Usage of Hooks](../tutorials/hook.md).
+
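The correspondence between the two configuration styles above can be sketched as a small conversion function. `convert_hook_configs` is a hypothetical helper written for this illustration and is not part of MMCV or MMEngine; it only covers the options used in the example (the `step` policy with a list of steps, optional gradient clipping, and the checkpoint and logging intervals), and it assumes that MMCV's `grad_clip` dict can be passed unchanged as the `clip_grad` field of the optimizer wrapper config.

```python
def convert_hook_configs(lr_config, optimizer_config, checkpoint_config,
                         log_config, optimizer):
    """Hypothetical helper: map MMCV-style hook configs to MMEngine-style ones.

    Only the cases that appear in the example above are handled.
    """
    if lr_config.get('policy') != 'step' or isinstance(lr_config['step'], int):
        raise NotImplementedError(
            'this sketch only handles policy="step" with a list of steps')
    # LrUpdaterHook(policy='step') -> a MultiStepLR parameter scheduler
    param_scheduler = dict(
        type='MultiStepLR', milestones=list(lr_config['step']))

    # OptimizerHook -> OptimWrapper; gradient clipping moves into the wrapper
    optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
    grad_clip = optimizer_config.get('grad_clip')
    if grad_clip is not None:
        optim_wrapper['clip_grad'] = grad_clip  # assumption: same dict layout

    # checkpoint_config / log_config -> entries of default_hooks
    default_hooks = dict(
        logger=dict(type='LoggerHook', interval=log_config['interval']),
        param_scheduler=dict(type='ParamSchedulerHook'),
        checkpoint=dict(
            type='CheckpointHook', interval=checkpoint_config['interval']),
    )
    return dict(
        optim_wrapper=optim_wrapper,
        param_scheduler=param_scheduler,
        default_hooks=default_hooks,
    )


converted = convert_hook_configs(
    lr_config=dict(policy='step', step=[2, 3]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=5),
    log_config=dict(interval=100, hooks=[dict(type='TextLoggerHook')]),
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9),
)
print(converted['default_hooks'])
```

The returned dictionary has the same keys as the keyword arguments passed to `Runner` in the MMEngine example, so it could be unpacked into that call alongside the model and `train_cfg`.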
+
+## Implementation Comparison
+
+Take `CheckpointHook` as an example. Compared with [CheckpointHook](https://github.com/open-mmlab/mmcv/blob/v1.6.0/mmcv/runner/hooks/checkpoint.py) in MMCV, [CheckpointHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/checkpoint_hook.py) in MMEngine additionally implements the `after_val_epoch` method, because the new `CheckpointHook` also saves the optimal weights. In MMCV, this function is provided by `EvalHook`.
+
+- MMCV
+
+```python
+from mmcv.runner import Hook
+
+
+class CheckpointHook(Hook):
+    def before_run(self, runner):
+        """Initialize out_dir and file_client"""
+
+    def after_train_epoch(self, runner):
+        """Synchronize buffers and save model weights, for tasks trained in epochs"""
+
+    def after_train_iter(self, runner):
+        """Synchronize buffers and save model weights, for tasks trained in iterations"""
+```
+
+- MMEngine
+
+```python
+from mmengine.hooks import Hook
+
+
+class CheckpointHook(Hook):
+    def before_run(self, runner):
+        """Initialize out_dir and file_client"""
+
+    def after_train_epoch(self, runner):
+        """Synchronize buffers and save model weights, for tasks trained in epochs"""
+
+    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
+        """Synchronize buffers and save model weights, for tasks trained in iterations"""
+
+    def after_val_epoch(self, runner, metrics):
+        """Save optimal weights according to metrics"""
+```
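The idea behind `after_val_epoch(self, runner, metrics)` can be illustrated with a simplified, self-contained sketch. `BestCheckpointSketch` is a hypothetical class written for this illustration, not the MMEngine implementation: it only tracks which validation epochs would have triggered a "save best" and does not write any file. The `save_best` and `rule` names are used here as plain constructor arguments of the sketch, and the runner is replaced by a `SimpleNamespace` that carries only an `epoch` attribute.

```python
from types import SimpleNamespace


class BestCheckpointSketch:
    """Hypothetical, simplified stand-in for best-weights tracking."""

    def __init__(self, save_best, rule='greater'):
        if rule not in ('greater', 'less'):
            raise ValueError(f'rule must be "greater" or "less", got {rule!r}')
        self.save_best = save_best  # name of the metric to monitor
        self.rule = rule
        self.best_score = None
        self.saved = []  # (epoch, score) pairs that would trigger a save

    def _is_better(self, score):
        if self.best_score is None:
            return True
        if self.rule == 'greater':
            return score > self.best_score
        return score < self.best_score

    def after_val_epoch(self, runner, metrics):
        # metrics is the dictionary of evaluation results for this epoch.
        score = metrics[self.save_best]
        if self._is_better(score):
            self.best_score = score
            # A real hook would save the model weights here.
            self.saved.append((runner.epoch, score))


runner = SimpleNamespace(epoch=0)  # stand-in for the real runner
hook = BestCheckpointSketch(save_best='accuracy')
for epoch, accuracy in enumerate([0.71, 0.69, 0.75]):
    runner.epoch = epoch
    hook.after_val_epoch(runner, {'accuracy': accuracy})
print(hook.saved)
```

Because the metrics arrive as an argument of the mount point, the hook does not need to run evaluation itself, which is why evaluation could be moved out of the hook and into the validation loop.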