Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Docs] Add activation checkpointing usage (#1341)
parent ccd17571ce
commit 45ee96d0c4
@@ -68,6 +68,35 @@ runner = Runner(
runner.train()
```

## Gradient Checkpointing

```{note}
Gradient checkpointing is supported starting from MMEngine v0.8.5. For performance comparisons, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you run into any issues, feel free to report them in [#1319](https://github.com/open-mmlab/mmengine/pull/1319).
```

To enable gradient checkpointing (also known as activation checkpointing), set `activation_checkpointing` in the Runner's `cfg` parameter to the names of the submodules to checkpoint.

Taking [Get Started in 15 Minutes](../get_started/15_minutes.md) as an example:

```python
cfg = dict(
    # Dotted paths, relative to the model, of the submodules to checkpoint.
    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```
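
To illustrate what listing module names in `activation_checkpointing` might amount to, the sketch below resolves each dotted name (such as `'resnet.layer1'`) to a submodule and wraps its `forward`. This is a minimal sketch under stated assumptions, not MMEngine's implementation: `enable_checkpointing`, `wrap_forward`, `fake_checkpoint`, and the toy classes are all hypothetical names, and `fake_checkpoint` is a stand-in that only records the call so the example runs without PyTorch.

```python
from functools import wraps
from operator import attrgetter


def fake_checkpoint(function, *args):
    """Stand-in for a real checkpoint function (assumption: same call shape).

    It only records that it was used and then runs the function normally;
    it does not save any memory.
    """
    fake_checkpoint.calls.append(function)
    return function(*args)


fake_checkpoint.calls = []


def wrap_forward(forward, checkpoint_fn):
    """Return a forward that routes every call through checkpoint_fn."""
    @wraps(forward)
    def wrapper(*args):
        return checkpoint_fn(forward, *args)
    return wrapper


def enable_checkpointing(model, module_names, checkpoint_fn=fake_checkpoint):
    """Hypothetical helper: wrap the forward of each named submodule."""
    if isinstance(module_names, str):
        module_names = [module_names]
    for name in module_names:
        # 'resnet.layer1' resolves to model.resnet.layer1
        module = attrgetter(name)(model)
        module.forward = wrap_forward(module.forward, checkpoint_fn)
    return model


# Toy classes standing in for a real network (hypothetical).
class Layer:
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


class Backbone:
    def __init__(self):
        self.layer1 = Layer(2)
        self.layer2 = Layer(3)


class ToyModel:
    def __init__(self):
        self.resnet = Backbone()

    def forward(self, x):
        return self.resnet.layer2.forward(self.resnet.layer1.forward(x))


model = enable_checkpointing(ToyModel(), ['resnet.layer1'])
result = model.forward(5)
```

Only `resnet.layer1` is routed through the checkpoint function here; `resnet.layer2` runs unchanged. With real `torch.nn.Module` objects, one would presumably pass `torch.utils.checkpoint.checkpoint` as `checkpoint_fn`, which trades extra computation in the backward pass for lower activation memory.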

## Large Model Training

```{warning}
@@ -68,6 +68,35 @@ runner = Runner(
runner.train()
```

## Gradient Checkpointing

```{note}
Gradient checkpointing is supported starting from MMEngine v0.8.5. For performance comparisons, see [#1319](https://github.com/open-mmlab/mmengine/pull/1319). If you run into any issues, feel free to report them in [#1319](https://github.com/open-mmlab/mmengine/pull/1319).
```

To enable gradient checkpointing, simply configure `activation_checkpointing` in the Runner's `cfg` parameter.

Taking [Get Started in 15 Minutes](../get_started/15_minutes.md) as an example:

```python
cfg = dict(
    # Dotted paths, relative to the model, of the submodules to checkpoint.
    activation_checkpointing=['resnet.layer1', 'resnet.layer2', 'resnet.layer3']
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```

## Large Model Training

```{warning}