diff --git a/docs/en/common_usage/save_gpu_memory.md b/docs/en/common_usage/save_gpu_memory.md
index e85ff3a8..f691d10f 100644
--- a/docs/en/common_usage/save_gpu_memory.md
+++ b/docs/en/common_usage/save_gpu_memory.md
@@ -68,6 +68,35 @@ runner = Runner(
 runner.train()
 ```
 
+## Gradient Checkpointing
+
+```{note}
+Gradient checkpointing is supported starting from MMEngine v0.8.5. For performance comparisons, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you encounter any issues, feel free to give feedback in [#1319](https://github.com/open-mmlab/mmengine/pull/1319).
+```
+
+To enable gradient checkpointing, set `activation_checkpointing` in the Runner's `cfg` parameter.
+
+Let's take [Get Started in 15 Minutes](../get_started/15_minutes.md) as an example:
+
+```python
+cfg = dict(
+    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
+)
+runner = Runner(
+    model=MMResNet50(),
+    work_dir='./work_dir',
+    train_dataloader=train_dataloader,
+    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
+    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+    val_dataloader=val_dataloader,
+    val_cfg=dict(),
+    val_evaluator=dict(type=Accuracy),
+    launcher=args.launcher,
+    cfg=cfg,
+)
+runner.train()
+```
+
 ## Large Model Training
 
 ```{warning}
diff --git a/docs/zh_cn/common_usage/save_gpu_memory.md b/docs/zh_cn/common_usage/save_gpu_memory.md
index efa95264..29dc0ba8 100644
--- a/docs/zh_cn/common_usage/save_gpu_memory.md
+++ b/docs/zh_cn/common_usage/save_gpu_memory.md
@@ -68,6 +68,35 @@ runner = Runner(
 runner.train()
 ```
 
+## Gradient Checkpointing
+
+```{note}
+MMEngine has supported gradient checkpointing since v0.8.5. For a performance comparison, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you run into any problems while using it, feedback is welcome in [#1319](https://github.com/open-mmlab/mmengine/pull/1319).
+```
+
+To enable gradient checkpointing, just configure `activation_checkpointing` in the Runner's `cfg` parameter.
+
+Take [Get Started with MMEngine in 15 Minutes](../get_started/15_minutes.md) as an example:
+
+```python
+cfg = dict(
+    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
+)
+runner = Runner(
+    model=MMResNet50(),
+    work_dir='./work_dir',
+    train_dataloader=train_dataloader,
+    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
+    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+    val_dataloader=val_dataloader,
+    val_cfg=dict(),
+    val_evaluator=dict(type=Accuracy),
+    launcher=args.launcher,
+    cfg=cfg,
+)
+runner.train()
+```
+
 ## Large Model Training
 
 ```{warning}
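
The added docs show how to switch gradient checkpointing on but not what it trades. The following is a minimal pure-Python sketch of the general idea, not MMEngine's implementation: every name in it (`make_layer`, `grad_full`, `grad_checkpointed`, `segment`) is hypothetical and chosen only for illustration. It assumes a chain of toy scalar layers; the checkpointed version keeps only the input of each segment during the forward pass and recomputes the remaining inputs during the backward pass.

```python
import math


def make_layer(w):
    """Return (f, df): a toy scalar layer and its derivative w.r.t. the input."""
    def f(x):
        return math.tanh(w * x)

    def df(x):
        return w * (1.0 - math.tanh(w * x) ** 2)

    return f, df


def grad_full(layers, x):
    """Backprop through the chain, keeping every layer input in memory."""
    inputs = []
    for f, _ in layers:
        inputs.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), x_in in zip(reversed(layers), reversed(inputs)):
        grad *= df(x_in)
    return grad, len(inputs)


def grad_checkpointed(layers, x, segment):
    """Keep only the input of each segment; recompute the rest on the way back."""
    saved = {}
    for i, (f, _) in enumerate(layers):
        if i % segment == 0:
            saved[i] = x  # checkpoint: the only activation kept from forward
        x = f(x)
    grad = 1.0
    for start in sorted(saved, reverse=True):
        chunk = layers[start:start + segment]
        # Re-run the forward pass of this segment from its checkpoint.
        x_in, inputs = saved[start], []
        for f, _ in chunk:
            inputs.append(x_in)
            x_in = f(x_in)
        for (_, df), x_i in zip(reversed(chunk), reversed(inputs)):
            grad *= df(x_i)
    return grad, len(saved)


# Hypothetical six-layer chain, split into two segments of three layers.
layers = [make_layer(w) for w in (0.5, 1.0, 1.5, 2.0, 0.7, 1.2)]
g_full, kept_full = grad_full(layers, 0.3)
g_ckpt, kept_ckpt = grad_checkpointed(layers, 0.3, segment=3)
print(g_full, g_ckpt, kept_full, kept_ckpt)
```

With six layers and `segment=3`, the forward pass keeps 2 inputs instead of 6 and both functions return the same gradient; the price is one extra forward pass per segment. Read this way, the modules listed under `activation_checkpointing` in the diff are presumably the segments whose activations are recomputed rather than stored.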