From a5284165c687909f3476a9523946ab91e242b24e Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Tue, 30 Aug 2022 19:05:51 +0800
Subject: [PATCH] Add the distributed training script (#487)

* Add the distributed training script

* fix md format
---
 docs/zh_cn/examples/speed_up_training.md |  6 +-
 examples/distributed_training.py         | 98 ++++++++++++++++++++++++
 2 files changed, 102 insertions(+), 2 deletions(-)
 create mode 100644 examples/distributed_training.py

diff --git a/docs/zh_cn/examples/speed_up_training.md b/docs/zh_cn/examples/speed_up_training.md
index 131bf88c..6bf9456d 100644
--- a/docs/zh_cn/examples/speed_up_training.md
+++ b/docs/zh_cn/examples/speed_up_training.md
@@ -61,9 +61,11 @@ srun -p mm_dev \
 
 ## Mixed Precision Training
 
-Nvidia introduced Tensor Core units in the Volta and Turing architectures to support mixed-precision computation with FP32 and FP16. Once automatic mixed precision training is enabled, some operators run in FP16 while the rest run in FP32. This shortens training time and reduces memory usage without changing the model or degrading its training accuracy, and therefore supports training with larger batch sizes, larger models, and larger inputs. [PyTorch has officially supported amp since version 1.6](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). If you are interested in how automatic mixed precision is implemented, you can read [torch.cuda.amp: automatic mixed precision explained](https://zhuanlan.zhihu.com/p/348554267).
+Nvidia introduced Tensor Core units in the Volta and Turing architectures to support mixed-precision computation with FP32 and FP16. Once automatic mixed precision training is enabled, some operators run in FP16 while the rest run in FP32. This shortens training time and reduces memory usage without changing the model or degrading its training accuracy, and therefore supports training with larger batch sizes, larger models, and larger inputs.
 
-MMEngine provides [AmpOptimWrapper](https://mmengine.readthedocs.io/zh_cn/latest/api.html#mmengine.optim.AmpOptimWrapper), a wrapper for automatic mixed precision. Simply set `type='AmpOptimWrapper'` in `optim_wrapper` to enable automatic mixed precision training; no other code changes are required.
+[PyTorch has officially supported amp since version 1.6](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). If you are interested in how automatic mixed precision is implemented, you can read [torch.cuda.amp: automatic mixed precision explained](https://zhuanlan.zhihu.com/p/348554267).
+
+MMEngine provides [AmpOptimWrapper](mmengine.optim.AmpOptimWrapper), a wrapper for automatic mixed precision. Simply set `type='AmpOptimWrapper'` in `optim_wrapper` to enable automatic mixed precision training; no other code changes are required.
 
 ```python
 runner = Runner(
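
For reference, a minimal sketch of the `optim_wrapper` config the patched paragraph describes (not part of this patch; the SGD settings are placeholder values borrowed from the example script below):

```python
from torch.optim import SGD

# Enabling automatic mixed precision in MMEngine only touches this dict:
# 'AmpOptimWrapper' replaces the default 'OptimWrapper', while the
# optimizer settings themselves stay unchanged.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
```
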
diff --git a/examples/distributed_training.py b/examples/distributed_training.py
new file mode 100644
index 00000000..6910c6dd
--- /dev/null
+++ b/examples/distributed_training.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from mmengine.evaluator import BaseMetric
+from mmengine.model import BaseModel
+from mmengine.runner import Runner
+
+
+class MMResNet50(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.resnet = torchvision.models.resnet50()
+
+    def forward(self, imgs, labels, mode):
+        x = self.resnet(imgs)
+        if mode == 'loss':
+            return {'loss': F.cross_entropy(x, labels)}
+        elif mode == 'predict':
+            return x, labels
+
+
+class Accuracy(BaseMetric):
+
+    def process(self, data_batch, data_samples):
+        score, gt = data_samples
+        self.results.append({
+            'batch_size': len(gt),
+            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
+        })
+
+    def compute_metrics(self, results):
+        total_correct = sum(item['correct'] for item in results)
+        total_size = sum(item['batch_size'] for item in results)
+        return dict(accuracy=100 * total_correct / total_size)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Distributed Training')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
+    train_dataloader = DataLoader(
+        batch_size=32,
+        shuffle=True,
+        dataset=torchvision.datasets.CIFAR10(
+            'data/cifar10',
+            train=True,
+            download=True,
+            transform=transforms.Compose([
+                transforms.RandomCrop(32, padding=4),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(**norm_cfg)
+            ])))
+    val_dataloader = DataLoader(
+        batch_size=32,
+        shuffle=False,
+        dataset=torchvision.datasets.CIFAR10(
+            'data/cifar10',
+            train=False,
+            download=True,
+            transform=transforms.Compose(
+                [transforms.ToTensor(),
+                 transforms.Normalize(**norm_cfg)])))
+    runner = Runner(
+        model=MMResNet50(),
+        work_dir='./work_dir',
+        train_dataloader=train_dataloader,
+        optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
+        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+        val_dataloader=val_dataloader,
+        val_cfg=dict(),
+        val_evaluator=dict(type=Accuracy),
+        launcher=args.launcher,
+    )
+    runner.train()
+
+
+if __name__ == '__main__':
+    main()
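
A quick sanity check of how the `Accuracy` metric added by this patch behaves (a minimal sketch, not part of the patch; it assumes the `Accuracy` class from `examples/distributed_training.py` above, and the logits and labels are made-up values standing in for what `MMResNet50.forward(..., mode='predict')` returns through the Runner):

```python
import torch

metric = Accuracy()

# First simulated validation batch: argmax picks classes [0, 1, 0]
# against labels [0, 1, 1], so 2 of 3 predictions are correct.
metric.process(
    data_batch=None,
    data_samples=(torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]),
                  torch.tensor([0, 1, 1])))

# Second simulated batch: both predictions correct.
metric.process(
    data_batch=None,
    data_samples=(torch.tensor([[0.4, 0.6], [0.1, 0.9]]),
                  torch.tensor([1, 1])))

# 4 correct out of 5 -> accuracy 80.0 (returned as a tensor, since
# 'correct' is accumulated as a tensor in process()).
print(metric.compute_metrics(metric.results))
```

To actually exercise the distributed path, something like `python -m torch.distributed.launch --nproc_per_node=8 examples/distributed_training.py --launcher pytorch` should work, since that launcher passes the `--local_rank` argument the script declares; the exact command depends on your PyTorch version and environment.
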