mmengine/docs/zh_cn/common_usage/distributed_training.md

3.1 KiB
Raw Blame History

分布式训练

MMEngine 支持 CPU、单卡、单机多卡以及多机多卡的训练。当环境中有多张显卡时我们可以使用以下命令开启单机多卡或者多机多卡的方式从而缩短模型的训练时间。

启动训练

单机多卡

假设当前机器有 8 张显卡,可以使用以下命令开启多卡训练

python -m torch.distributed.launch --nproc_per_node=8 examples/distributed_training.py --launcher pytorch

如果需要指定显卡的编号,可以设置 CUDA_VISIBLE_DEVICES 环境变量,例如使用第 0 和第 3 张卡

CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 examples/distributed_training.py --launcher pytorch

多机多卡

假设有 2 台机器,每台机器有 8 张卡。

第一台机器运行以下命令

python -m torch.distributed.launch \
    --nnodes 8 \
    --node_rank 0 \
    --master_addr 127.0.0.1 \
    --master_port 29500 \
    --nproc_per_node=8 \
    examples/distributed_training.py --launcher pytorch

第 2 台机器运行以下命令

python -m torch.distributed.launch \
    --nnodes 8 \
    --node_rank 1 \
    --master_addr 127.0.0.1 \
    --master_port 29500 \
    --nproc_per_node=8 \
    examples/distributed_training.py --launcher pytorch

如果在 slurm 集群运行 MMEngine只需运行以下命令即可开启 2 机 16 卡的训练

srun -p mm_dev \
    --job-name=test \
    --gres=gpu:8 \
    --ntasks=16 \
    --ntasks-per-node=8 \
    --cpus-per-task=5 \
    --kill-on-bad-exit=1 \
    python examples/distributed_training.py --launcher="slurm"

定制化分布式训练

当用户从单卡训练切换到多卡训练时,无需做任何改动,Runner 会默认使用 MMDistributedDataParallel 封装 model从而支持多卡训练。 如果你希望给 MMDistributedDataParallel 传入更多的参数或者使用自定义的 CustomDistributedDataParallel,你可以设置 model_wrapper_cfg

往 MMDistributedDataParallel 传入更多的参数

例如设置 find_unused_parametersTrue

cfg = dict(
    model_wrapper_cfg=dict(
        type='MMDistributedDataParallel', find_unused_parameters=True)
)
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
    cfg=cfg,
)
runner.train()

使用自定义的 CustomDistributedDataParallel

from mmengine.registry import MODEL_WRAPPERS

@MODEL_WRAPPERS.register_module()
class CustomDistributedDataParallel(DistributedDataParallel):
    pass


cfg = dict(model_wrapper_cfg=dict(type='CustomDistributedDataParallel'))
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
    cfg=cfg,
)
runner.train()