
# Speed up Training

## Distributed Training

MMEngine supports training models on a CPU, a single GPU, multiple GPUs on a single machine, and multiple machines. When multiple GPUs are available, you can use the following commands to enable single-machine or multi-machine multi-GPU training and shorten the training time.

- **Multiple GPUs on a single machine**

  Assuming the current machine has 8 GPUs, you can enable multi-GPU training with the following command:

  ```bash
  python -m torch.distributed.launch --nproc_per_node=8 examples/train.py --launcher pytorch
  ```

  If you need to specify GPU indices, you can set the `CUDA_VISIBLE_DEVICES` environment variable, e.g. to use the 0th and 3rd GPUs:

  ```bash
  CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 examples/train.py --launcher pytorch
  ```
- **Multiple machines**

  Assuming there are 2 machines connected via Ethernet, you can simply run the following commands.

  On the first machine:

  ```bash
  python -m torch.distributed.launch \
      --nnodes 2 \
      --node_rank 0 \
      --master_addr 127.0.0.1 \
      --master_port 29500 \
      --nproc_per_node=8 \
      examples/train.py --launcher pytorch
  ```

  On the second machine:

  ```bash
  python -m torch.distributed.launch \
      --nnodes 2 \
      --node_rank 1 \
      --master_addr 127.0.0.1 \
      --master_port 29500 \
      --nproc_per_node=8 \
      examples/train.py --launcher pytorch
  ```

  Note that in a real two-machine setup, `--master_addr` must be the IP address of the first machine as reachable from the second; `127.0.0.1` only works when all processes run on the same host.

  If you are running MMEngine on a Slurm cluster, simply run the following command to enable training with 2 machines and 16 GPUs:

  ```bash
  srun -p mm_dev \
      --job-name=test \
      --gres=gpu:8 \
      --ntasks=16 \
      --ntasks-per-node=8 \
      --cpus-per-task=5 \
      --kill-on-bad-exit=1 \
      python examples/train.py --launcher="slurm"
  ```
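With `--launcher pytorch`, each worker process discovers its role through the environment variables that `torch.distributed.launch` exports (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR`, `MASTER_PORT`). A minimal sketch of how such variables can be read — `read_dist_env` is a hypothetical helper, not MMEngine API, and the values are set by hand here to simulate the launcher:

```python
import os

def read_dist_env():
    """Read the rendezvous info that torch.distributed.launch exports
    to every worker process (hypothetical helper for illustration)."""
    return {
        "rank": int(os.environ["RANK"]),              # global rank across machines
        "world_size": int(os.environ["WORLD_SIZE"]),  # total number of processes
        "local_rank": int(os.environ["LOCAL_RANK"]),  # GPU index on this machine
        "master": (os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])),
    }

# Simulate what the launcher would set for local GPU 3 on the second machine
# of a 2-machine x 8-GPU job: global rank = 1 * 8 + 3 = 11, world size = 16.
os.environ.update(RANK="11", WORLD_SIZE="16", LOCAL_RANK="3",
                  MASTER_ADDR="127.0.0.1", MASTER_PORT="29500")
env = read_dist_env()
print(env["rank"], env["world_size"], env["local_rank"])  # 11 16 3
```

This is why `--node_rank` must differ per machine while `--master_addr`/`--master_port` must be identical everywhere: they are combined into one consistent set of per-worker variables.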

## Mixed Precision Training

Nvidia introduced Tensor Core units in the Volta and Turing architectures to support mixed FP32/FP16 precision computing, and added BF16 support in the Ampere architecture. With automatic mixed precision training enabled, some operators run in FP16/BF16 while the rest run in FP32, which reduces training time and memory requirements without changing the model or degrading its training precision, thus supporting training with larger batch sizes, larger models, and larger input sizes.

PyTorch has officially supported AMP since version 1.6. If you are interested in how automatic mixed precision is implemented, you can refer to Mixed Precision Training.
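To see why BF16 matters: bfloat16 keeps float32's 8-bit exponent (and thus its dynamic range) but truncates the mantissa to 7 bits, while float16 has only a 5-bit exponent and overflows just above 65504 — which is why FP16 training needs loss scaling. A standalone bit-level illustration (not MMEngine/PyTorch code; `to_bfloat16` is a hypothetical helper for demonstration):

```python
import struct

def to_bfloat16(x: float) -> float:
    """Truncate a float32 value to bfloat16 by dropping the low 16
    mantissa bits; the exponent range is unchanged from float32."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

# 70000.0 is representable in bfloat16, only more coarsely than in float32:
print(to_bfloat16(70000.0))  # 69632.0

# ...but it exceeds float16's maximum finite value (65504):
try:
    struct.pack(">e", 70000.0)  # ">e" is IEEE 754 half precision
    fits_in_fp16 = True
except (OverflowError, struct.error):
    fits_in_fp16 = False
print(fits_in_fp16)  # False
```

Precision is the trade-off: values that are distinct in float32 can collapse to the same bfloat16 value, which is acceptable for gradients but is why master weights stay in FP32.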

MMEngine provides the `AmpOptimWrapper` wrapper for automatic mixed precision training. Just set `type='AmpOptimWrapper'` in `optim_wrapper` to enable it; no other code changes are needed.

```python
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(
        type='AmpOptimWrapper',
        # If you want to use bfloat16, uncomment the following line
        # dtype='bfloat16',  # valid values: ('float16', 'bfloat16', None)
        optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()
```
Until PyTorch 1.13, `torch.bfloat16` performance on `Convolution` is poor unless the environment variable `TORCH_CUDNN_V8_API_ENABLED=1` is set manually. More context at [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767).
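If you train with `dtype='bfloat16'` on an affected PyTorch version, one way to opt in is to set the variable before PyTorch (and thus cuDNN) initializes, e.g. at the very top of your training script — a sketch, not MMEngine API:

```python
import os

# Enable cuDNN's v8 API so bfloat16 convolutions hit fast kernels on
# PyTorch < 1.13. This must run before `import torch`, so it belongs at
# the top of the training script; alternatively, export the variable in
# the shell before launching the job.
os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
```

Using `setdefault` rather than plain assignment keeps any value already exported in the launching shell.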