mmengine/docs/zh_cn/common_usage/large_model_training.md
2023-07-05 18:20:07 +08:00


Large Model Training

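Training large models requires enormous resources, and the memory of a single GPU is usually not enough. Large model training techniques were developed to address this. A typical one is DeepSpeed ZeRO, which supports partitioning optimizer states, gradients, and parameters.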
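To make the effect of partitioning concrete, here is a rough back-of-the-envelope sketch. It assumes mixed-precision training with Adam and the usual accounting of 2 bytes per parameter for fp16 weights, 2 bytes for fp16 gradients, and 12 bytes for fp32 optimizer states. It also assumes that ZeRO stages 1, 2, and 3 partition optimizer states, then gradients, then parameters. The function name `zero_bytes_per_gpu` is hypothetical, and activations, buffers, and fragmentation are ignored.

```python
def zero_bytes_per_gpu(num_params, num_gpus, stage):
    """Approximate model-state memory per GPU in bytes (assumed accounting)."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 master weights + Adam momentum + variance
    if stage >= 1:
        optim /= num_gpus     # stage 1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus     # stage 2: also partition gradients
    if stage >= 3:
        params /= num_gpus    # stage 3: also partition parameters
    return params + grads + optim


if __name__ == '__main__':
    # Hypothetical 1B-parameter model trained on 2 GPUs.
    for stage in range(4):
        gib = zero_bytes_per_gpu(1_000_000_000, 2, stage) / 2**30
        print(f'stage {stage}: {gib:.1f} GiB of model states per GPU')
```

Under these assumptions, the savings grow with both the stage and the number of GPUs, which is why stage 3 is attractive when the model itself does not fit on one card.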

To support large model training techniques more flexibly, MMEngine provides a new runner, FlexibleRunner, and several abstract Strategy classes, starting from v0.8.0.

The new FlexibleRunner and Strategy are still experimental, and their interfaces may change in future releases.

The example code below is taken from examples/distributed_training_with_flexible_runner.py.

DeepSpeed

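DeepSpeed is an open-source distributed training framework from Microsoft, built on PyTorch. It supports training strategies such as ZeRO, 3D-Parallelism, DeepSpeed-MoE, and ZeRO-Infinity. MMEngine has supported training models with DeepSpeed since v0.8.0.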

Install deepspeed before using DeepSpeed:

pip install deepspeed

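After installing deepspeed, configure the strategy and optim_wrapper parameters of FlexibleRunner:

  • strategy: set type='DeepSpeedStrategy' and configure its parameters. See DeepSpeedStrategy for details on the parameters.
  • optim_wrapper: set type='DeepSpeedOptimWrapper' and configure its parameters. See DeepSpeedOptimWrapper for details on the parameters.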

The DeepSpeed-related configuration is shown below:

from mmengine.runner._flexible_runner import FlexibleRunner

# Specify DeepSpeedStrategy and configure its parameters
strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(
        enabled=True,
        fp16_master_weights_and_grads=False,
        loss_scale=0,
        loss_scale_window=500,
        hysteresis=2,
        min_loss_scale=1,
        initial_scale_power=15,
    ),
    inputs_to_half=[0],
    zero_optimization=dict(
        stage=3,
        allgather_partitions=True,
        reduce_scatter=True,
        allgather_bucket_size=50000000,
        reduce_bucket_size=50000000,
        overlap_comm=True,
        contiguous_gradients=True,
        cpu_offload=False),
)

# Specify DeepSpeedOptimWrapper and configure its parameters
optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-3))

# Initialize FlexibleRunner (MMResNet50, the dataloaders, and Accuracy
# are defined elsewhere in the example script)
runner = FlexibleRunner(
    model=MMResNet50(),
    work_dir='./work_dirs',
    strategy=strategy,
    train_dataloader=train_dataloader,
    optim_wrapper=optim_wrapper,
    param_scheduler=dict(type='LinearLR'),
    train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy))

# Start training
runner.train()

Launch distributed training on two GPUs:

torchrun --nproc-per-node 2 examples/distributed_training_with_flexible_runner.py --use-deepspeed
Training log
07/03 13:04:17 - mmengine - INFO - Epoch(train)  [1][ 10/196]  lr: 3.3333e-04  eta: 0:13:14  time: 0.4073  data_time: 0.0335  memory: 970  loss: 6.1887
07/03 13:04:19 - mmengine - INFO - Epoch(train)  [1][ 20/196]  lr: 3.3333e-04  eta: 0:09:39  time: 0.1904  data_time: 0.0327  memory: 970  loss: 2.5746
07/03 13:04:21 - mmengine - INFO - Epoch(train)  [1][ 30/196]  lr: 3.3333e-04  eta: 0:08:32  time: 0.1993  data_time: 0.0342  memory: 970  loss: 2.4180
07/03 13:04:23 - mmengine - INFO - Epoch(train)  [1][ 40/196]  lr: 3.3333e-04  eta: 0:08:01  time: 0.2052  data_time: 0.0368  memory: 970  loss: 2.3682
07/03 13:04:25 - mmengine - INFO - Epoch(train)  [1][ 50/196]  lr: 3.3333e-04  eta: 0:07:39  time: 0.2013  data_time: 0.0356  memory: 970  loss: 2.3025
07/03 13:04:27 - mmengine - INFO - Epoch(train)  [1][ 60/196]  lr: 3.3333e-04  eta: 0:07:25  time: 0.2025  data_time: 0.0353  memory: 970  loss: 2.2078
07/03 13:04:29 - mmengine - INFO - Epoch(train)  [1][ 70/196]  lr: 3.3333e-04  eta: 0:07:13  time: 0.1999  data_time: 0.0352  memory: 970  loss: 2.2045
07/03 13:04:31 - mmengine - INFO - Epoch(train)  [1][ 80/196]  lr: 3.3333e-04  eta: 0:07:04  time: 0.2013  data_time: 0.0350  memory: 970  loss: 2.1709
07/03 13:04:33 - mmengine - INFO - Epoch(train)  [1][ 90/196]  lr: 3.3333e-04  eta: 0:06:56  time: 0.1975  data_time: 0.0341  memory: 970  loss: 2.2070
07/03 13:04:35 - mmengine - INFO - Epoch(train)  [1][100/196]  lr: 3.3333e-04  eta: 0:06:49  time: 0.1993  data_time: 0.0347  memory: 970  loss: 2.0891
07/03 13:04:37 - mmengine - INFO - Epoch(train)  [1][110/196]  lr: 3.3333e-04  eta: 0:06:44  time: 0.1995  data_time: 0.0357  memory: 970  loss: 2.0700
07/03 13:04:39 - mmengine - INFO - Epoch(train)  [1][120/196]  lr: 3.3333e-04  eta: 0:06:38  time: 0.1966  data_time: 0.0342  memory: 970  loss: 1.9983
07/03 13:04:41 - mmengine - INFO - Epoch(train)  [1][130/196]  lr: 3.3333e-04  eta: 0:06:37  time: 0.2216  data_time: 0.0341  memory: 970  loss: 1.9409
07/03 13:04:43 - mmengine - INFO - Epoch(train)  [1][140/196]  lr: 3.3333e-04  eta: 0:06:32  time: 0.1944  data_time: 0.0336  memory: 970  loss: 1.9800
07/03 13:04:45 - mmengine - INFO - Epoch(train)  [1][150/196]  lr: 3.3333e-04  eta: 0:06:27  time: 0.1946  data_time: 0.0338  memory: 970  loss: 1.9356
07/03 13:04:47 - mmengine - INFO - Epoch(train)  [1][160/196]  lr: 3.3333e-04  eta: 0:06:22  time: 0.1937  data_time: 0.0333  memory: 970  loss: 1.8145
07/03 13:04:49 - mmengine - INFO - Epoch(train)  [1][170/196]  lr: 3.3333e-04  eta: 0:06:18  time: 0.1941  data_time: 0.0335  memory: 970  loss: 1.8525
07/03 13:04:51 - mmengine - INFO - Epoch(train)  [1][180/196]  lr: 3.3333e-04  eta: 0:06:17  time: 0.2204  data_time: 0.0341  memory: 970  loss: 1.7637
07/03 13:04:53 - mmengine - INFO - Epoch(train)  [1][190/196]  lr: 3.3333e-04  eta: 0:06:14  time: 0.1998  data_time: 0.0345  memory: 970  loss: 1.7523
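The fp16 block in the DeepSpeed configuration above sets loss_scale=0, which in DeepSpeed selects dynamic loss scaling, and initial_scale_power=15, which gives a starting scale of 2^15 = 32768. The sketch below illustrates the general idea of dynamic loss scaling only. It is a toy model, not DeepSpeed's implementation: the class name `ToyDynamicLossScaler` is hypothetical, the scale factor of 2 is an assumption, and the hysteresis setting is deliberately left out.

```python
class ToyDynamicLossScaler:
    """Simplified dynamic loss scaler (illustrative, hysteresis omitted)."""

    def __init__(self, initial_scale_power=15, loss_scale_window=500,
                 min_loss_scale=1):
        self.scale = 2.0 ** initial_scale_power
        self.window = loss_scale_window
        self.min_scale = min_loss_scale
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        """Adjust the scale after one training step and return it."""
        if overflow:
            # Gradients overflowed in fp16: shrink the scale, but not
            # below the configured minimum, and restart the window.
            self.scale = max(self.scale / 2, self.min_scale)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.window:
                # A full window without overflow: try a larger scale.
                self.scale *= 2
                self.good_steps = 0
        return self.scale


if __name__ == '__main__':
    scaler = ToyDynamicLossScaler()
    print('initial scale:', scaler.scale)
    print('after one overflow:', scaler.update(overflow=True))
```

In this simplified picture, loss_scale_window controls how cautiously the scale grows back after an overflow, and min_loss_scale bounds how far repeated overflows can push it down.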

FullyShardedDataParallel (FSDP)

PyTorch has supported FullyShardedDataParallel training since v1.11, but because its interface has kept changing, we only support PyTorch v2.0.0 and later.

To use FSDP, configure the strategy parameter of FlexibleRunner: set type='FSDPStrategy' and configure its parameters. See FSDPStrategy for details on the parameters.

The FSDP-related configuration is shown below:

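from functools import partial

from mmengine.runner._flexible_runner import FlexibleRunner
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Consider a submodule for wrapping once it holds at least 1e7 parameters
size_based_auto_wrap_policy = partial(
    size_based_auto_wrap_policy, min_num_params=1e7)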

# Specify FSDPStrategy and configure its parameters
strategy = dict(
    type='FSDPStrategy',
    model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy))

# Specify AmpOptimWrapper and configure its parameters
optim_wrapper = dict(
    type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))

# Initialize FlexibleRunner (MMResNet50, the dataloaders, and Accuracy
# are defined elsewhere in the example script)
runner = FlexibleRunner(
    model=MMResNet50(),
    work_dir='./work_dirs',
    strategy=strategy,
    train_dataloader=train_dataloader,
    optim_wrapper=optim_wrapper,
    param_scheduler=dict(type='LinearLR'),
    train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy))

# Start training
runner.train()
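The configuration above passes size_based_auto_wrap_policy with min_num_params=1e7 as the auto_wrap_policy. As a rough illustration of what a size-based rule does, the sketch below walks a toy module tree bottom-up and wraps a module when the parameters under it that are not already wrapped reach the threshold. This is a simplified model written in plain Python, not PyTorch's implementation. `ToyModule`, `plan_wrapping`, and all parameter counts are hypothetical, and the outermost FSDP wrapper around the whole model is not modeled.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    name: str
    own_params: int
    children: List['ToyModule'] = field(default_factory=list)

    def numel(self):
        """Total parameters held by this module and its descendants."""
        return self.own_params + sum(c.numel() for c in self.children)


def plan_wrapping(module, min_num_params, wrapped):
    """Record wrapped module names and return the wrapped parameter count."""
    total = module.numel()
    if total < min_num_params:
        return 0  # too small: neither recurse nor wrap
    wrapped_below = sum(
        plan_wrapping(child, min_num_params, wrapped)
        for child in module.children)
    remainder = total - wrapped_below  # parameters not yet wrapped
    if remainder >= min_num_params:
        wrapped.append(module.name)
        return total
    return wrapped_below


if __name__ == '__main__':
    model = ToyModule('model', 1_000_000, [
        ToyModule('stem', 2_000_000),
        ToyModule('stage1', 12_000_000),
        ToyModule('stage2', 0, [
            ToyModule('block_a', 15_000_000),
            ToyModule('block_b', 4_000_000),
        ]),
    ])
    wrapped = []
    plan_wrapping(model, 1e7, wrapped)
    print(wrapped)
```

A lower threshold produces more, smaller wrapped units, while a higher one produces fewer, larger units. The best value depends on the model and hardware.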

Launch distributed training on two GPUs:

torchrun --nproc-per-node 2 examples/distributed_training_with_flexible_runner.py --use-fsdp
Training log
07/03 13:05:37 - mmengine - INFO - Epoch(train)  [1][ 10/196]  lr: 3.3333e-04  eta: 0:08:28  time: 0.2606  data_time: 0.0330  memory: 954  loss: 6.1265
07/03 13:05:38 - mmengine - INFO - Epoch(train)  [1][ 20/196]  lr: 3.3333e-04  eta: 0:05:18  time: 0.0673  data_time: 0.0325  memory: 954  loss: 2.5584
07/03 13:05:39 - mmengine - INFO - Epoch(train)  [1][ 30/196]  lr: 3.3333e-04  eta: 0:04:13  time: 0.0666  data_time: 0.0320  memory: 954  loss: 2.4816
07/03 13:05:39 - mmengine - INFO - Epoch(train)  [1][ 40/196]  lr: 3.3333e-04  eta: 0:03:41  time: 0.0666  data_time: 0.0321  memory: 954  loss: 2.3695
07/03 13:05:40 - mmengine - INFO - Epoch(train)  [1][ 50/196]  lr: 3.3333e-04  eta: 0:03:21  time: 0.0671  data_time: 0.0324  memory: 954  loss: 2.3208
07/03 13:05:41 - mmengine - INFO - Epoch(train)  [1][ 60/196]  lr: 3.3333e-04  eta: 0:03:08  time: 0.0667  data_time: 0.0320  memory: 954  loss: 2.2431
07/03 13:05:41 - mmengine - INFO - Epoch(train)  [1][ 70/196]  lr: 3.3333e-04  eta: 0:02:58  time: 0.0667  data_time: 0.0320  memory: 954  loss: 2.1873
07/03 13:05:42 - mmengine - INFO - Epoch(train)  [1][ 80/196]  lr: 3.3333e-04  eta: 0:02:51  time: 0.0669  data_time: 0.0320  memory: 954  loss: 2.2006
07/03 13:05:43 - mmengine - INFO - Epoch(train)  [1][ 90/196]  lr: 3.3333e-04  eta: 0:02:45  time: 0.0671  data_time: 0.0324  memory: 954  loss: 2.1547
07/03 13:05:43 - mmengine - INFO - Epoch(train)  [1][100/196]  lr: 3.3333e-04  eta: 0:02:40  time: 0.0667  data_time: 0.0321  memory: 954  loss: 2.1361
07/03 13:05:44 - mmengine - INFO - Epoch(train)  [1][110/196]  lr: 3.3333e-04  eta: 0:02:36  time: 0.0668  data_time: 0.0320  memory: 954  loss: 2.0405
07/03 13:05:45 - mmengine - INFO - Epoch(train)  [1][120/196]  lr: 3.3333e-04  eta: 0:02:32  time: 0.0669  data_time: 0.0320  memory: 954  loss: 2.0228
07/03 13:05:45 - mmengine - INFO - Epoch(train)  [1][130/196]  lr: 3.3333e-04  eta: 0:02:29  time: 0.0670  data_time: 0.0324  memory: 954  loss: 2.0375
07/03 13:05:46 - mmengine - INFO - Epoch(train)  [1][140/196]  lr: 3.3333e-04  eta: 0:02:26  time: 0.0664  data_time: 0.0320  memory: 954  loss: 1.9926
07/03 13:05:47 - mmengine - INFO - Epoch(train)  [1][150/196]  lr: 3.3333e-04  eta: 0:02:24  time: 0.0668  data_time: 0.0320  memory: 954  loss: 1.9820
07/03 13:05:47 - mmengine - INFO - Epoch(train)  [1][160/196]  lr: 3.3333e-04  eta: 0:02:22  time: 0.0674  data_time: 0.0325  memory: 954  loss: 1.9728
07/03 13:05:48 - mmengine - INFO - Epoch(train)  [1][170/196]  lr: 3.3333e-04  eta: 0:02:20  time: 0.0666  data_time: 0.0320  memory: 954  loss: 1.9359
07/03 13:05:49 - mmengine - INFO - Epoch(train)  [1][180/196]  lr: 3.3333e-04  eta: 0:02:18  time: 0.0667  data_time: 0.0321  memory: 954  loss: 1.9488
07/03 13:05:49 - mmengine - INFO - Epoch(train)  [1][190/196]  lr: 3.3333e-04  eta: 0:02:16  time: 0.0671  data_time: 0.0323  memory: 954  loss: 1.9023
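Both launch commands in this document run the same script and differ only in the --use-deepspeed or --use-fsdp flag. The sketch below shows one way a script might map those flags to the strategy and optim_wrapper configs described above. It is an illustration rather than the contents of the example script: `parse_args` and `build_configs` are hypothetical helpers, the configs are abbreviated, and nothing here imports MMEngine.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description='Pick a training strategy')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--use-deepspeed', action='store_true')
    group.add_argument('--use-fsdp', action='store_true')
    return parser.parse_args(argv)


def build_configs(args):
    """Return abbreviated (strategy, optim_wrapper) config dicts."""
    optimizer = dict(type='AdamW', lr=1e-3)
    if args.use_deepspeed:
        # DeepSpeed needs its own strategy and its own optimizer wrapper.
        strategy = dict(
            type='DeepSpeedStrategy',
            zero_optimization=dict(stage=3))
        optim_wrapper = dict(
            type='DeepSpeedOptimWrapper', optimizer=optimizer)
    else:
        # FSDP is paired with AmpOptimWrapper in the FSDP section above.
        strategy = dict(type='FSDPStrategy')
        optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer)
    return strategy, optim_wrapper


if __name__ == '__main__':
    strategy, optim_wrapper = build_configs(parse_args())
    print(strategy)
    print(optim_wrapper)
```

Keeping the choice in one place means the rest of the script, including the FlexibleRunner construction, can stay identical for both backends.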