[Docs] Add activation checkpointing usage (#1341)

Zaida Zhou 2023-09-05 11:23:44 +08:00 committed by GitHub
parent ccd17571ce
commit 45ee96d0c4
2 changed files with 58 additions and 0 deletions


@@ -68,6 +68,35 @@ runner = Runner(
runner.train()
```
## Gradient Checkpointing
```{note}
Gradient checkpointing is supported starting from MMEngine v0.8.5. For performance comparisons, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you run into any issues while using it, please report them in [#1319](https://github.com/open-mmlab/mmengine/pull/1319) as well.
```
To enable gradient checkpointing, set `activation_checkpointing` in the Runner's `cfg` parameter to the list of submodules that should be checkpointed.
Taking [Get Started in 15 Minutes](../get_started/15_minutes.md) as an example:
```python
cfg = dict(
    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```
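The strings in `activation_checkpointing` read as dotted attribute paths into the model: in the tutorial, `MMResNet50` keeps the torchvision ResNet-50 in `self.resnet`, so `'resnet.layer1'` points at its first stage. The sketch below shows how such names can be resolved with `operator.attrgetter`. It is only an illustration under the assumption that lookup works like plain attribute access; the `model` object and the `resolve_submodules` helper are hypothetical stand-ins, not MMEngine API.

```python
from operator import attrgetter
from types import SimpleNamespace

# Hypothetical stand-in for the tutorial's MMResNet50: an object whose
# `resnet` attribute exposes layer1..layer4, as torchvision's ResNet does.
model = SimpleNamespace(
    resnet=SimpleNamespace(
        layer1='stage 1',
        layer2='stage 2',
        layer3='stage 3',
        layer4='stage 4',
    )
)


def resolve_submodules(model, names):
    """Map each dotted name to the attribute it refers to on `model`.

    Illustrative helper (an assumption, not part of MMEngine). A wrong
    name raises AttributeError, which is a quick way to catch typos
    before starting a training run.
    """
    if isinstance(names, str):
        names = [names]
    return {name: attrgetter(name)(model) for name in names}


selected = resolve_submodules(
    model, ['resnet.layer1', 'resnet.layer2', 'resnet.layer3'])
for name, module in selected.items():
    print(f'{name} -> {module}')
```

Running the same lookup against your own model before training is a cheap sanity check that every configured name exists.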
## Large Model Training
```{warning}


@@ -68,6 +68,35 @@ runner = Runner(
runner.train()
```
## Gradient Checkpointing
```{note}
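Gradient checkpointing is supported starting from MMEngine v0.8.5. For performance comparisons, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you run into any issues while using it, please report them in [#1319](https://github.com/open-mmlab/mmengine/pull/1319).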
```
To enable gradient checkpointing, just set `activation_checkpointing` in the Runner's `cfg` parameter.
Taking [Get Started with MMEngine in 15 Minutes](../get_started/15_minutes.md) as an example:
```python
cfg = dict(
    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```
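For intuition about what checkpointing the listed layers buys, here is a toy, pure-Python sketch of the general technique: only the input of each checkpointed segment is kept during the forward pass, and the activations inside the segment are recomputed during the backward pass. It is not MMEngine's or PyTorch's implementation; the scalar "layers" and all function names are made up for illustration.

```python
def forward_layer(w, x):
    # Toy layer y = w * x^2; its gradient needs the saved input x.
    return w * x * x


def grad_layer(w, x, grad_out):
    # dy/dx = 2 * w * x, chained with the incoming gradient.
    return grad_out * 2.0 * w * x


def backprop_plain(weights, x):
    """Save every layer input; returns (grad, saved values, forward calls)."""
    saved = []
    for w in weights:
        saved.append(x)
        x = forward_layer(w, x)
    grad = 1.0
    for w, x_in in zip(reversed(weights), reversed(saved)):
        grad = grad_layer(w, x_in, grad)
    return grad, len(saved), len(weights)


def backprop_checkpointed(weights, x, segment):
    """Save only segment inputs, then recompute inside each segment."""
    starts = list(range(0, len(weights), segment))
    checkpoints, forward_calls = [], 0
    for start in starts:
        checkpoints.append(x)  # keep the segment input only
        for w in weights[start:start + segment]:
            x = forward_layer(w, x)
            forward_calls += 1
    grad, peak = 1.0, len(checkpoints)
    for start, x_seg in zip(reversed(starts), reversed(checkpoints)):
        seg_weights = weights[start:start + segment]
        saved, x = [], x_seg
        for w in seg_weights:  # recompute this segment's activations
            saved.append(x)
            x = forward_layer(w, x)
            forward_calls += 1
        peak = max(peak, len(checkpoints) + len(saved))
        for w, x_in in zip(reversed(seg_weights), reversed(saved)):
            grad = grad_layer(w, x_in, grad)
    return grad, peak, forward_calls


weights = [2.0] * 16
plain = backprop_plain(weights, 0.5)
ckpt = backprop_checkpointed(weights, 0.5, segment=4)
print('plain:        grad=%s saved=%d forward_calls=%d' % plain)
print('checkpointed: grad=%s saved=%d forward_calls=%d' % ckpt)
```

Both variants produce the same gradient; the checkpointed one holds fewer values at its peak and pays for that with extra forward computations, which is the trade-off to keep in mind when choosing which layers to list.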
## Large Model Training
```{warning}