Add the distributed training script (#487)
* Add the distributed training script
* fix md format
pull/404/head^2
parent
fbd31d2041
commit
a5284165c6
docs/zh_cn/examples
examples
|
@ -61,9 +61,11 @@ srun -p mm_dev \
|
|||
|
||||
## 混合精度训练
|
||||
|
||||
Nvidia 在 Volta 和 Turing 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。开启自动混合精度训练后,部分算子的操作精度是 FP16,其余算子的操作精度是 FP32。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。[PyTorch 从 1.6 开始官方支持 amp](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/)。如果你对自动混合精度的实现感兴趣,可以阅读 [torch.cuda.amp: 自动混合精度详解](https://zhuanlan.zhihu.com/p/348554267)。
|
||||
Nvidia 在 Volta 和 Turing 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。开启自动混合精度训练后,部分算子的操作精度是 FP16,其余算子的操作精度是 FP32。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。
|
||||
|
||||
MMEngine 提供自动混合精度的封装 [AmpOptimWrapper](https://mmengine.readthedocs.io/zh_cn/latest/api.html#mmengine.optim.AmpOptimWrapper) ,只需在 `optim_wrapper` 设置 `type='AmpOptimWrapper'` 即可开启自动混合精度训练,无需对代码做其他修改。
|
||||
[PyTorch 从 1.6 开始官方支持 amp](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/)。如果你对自动混合精度的实现感兴趣,可以阅读 [torch.cuda.amp: 自动混合精度详解](https://zhuanlan.zhihu.com/p/348554267)。
|
||||
|
||||
MMEngine 提供自动混合精度的封装 [AmpOptimWrapper](mmengine.optim.AmpOptimWrapper) ,只需在 `optim_wrapper` 设置 `type='AmpOptimWrapper'` 即可开启自动混合精度训练,无需对代码做其他修改。
|
||||
|
||||
```python
|
||||
runner = Runner(
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.optim import SGD
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
class MMResNet50(BaseModel):
    """``BaseModel`` wrapper around a torchvision ResNet-50.

    The wrapped network is created with ``torchvision.models.resnet50()``
    using its default arguments.
    """

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        """Run the network and shape the output according to ``mode``.

        Args:
            imgs: Input batch passed to the wrapped ResNet.
            labels: Targets matching ``imgs``.
            mode: ``'loss'`` or ``'predict'``.

        Returns:
            For ``'loss'``, a dict with the single key ``'loss'`` holding the
            cross-entropy between the network output and ``labels``. For
            ``'predict'``, the tuple ``(network_output, labels)``. Any other
            mode yields ``None``.
        """
        logits = self.resnet(imgs)
        if mode == 'predict':
            return logits, labels
        if mode == 'loss':
            return {'loss': F.cross_entropy(logits, labels)}
        # Unrecognised modes fall through without a result.
        return None
|
||||
|
||||
|
||||
class Accuracy(BaseMetric):
    """Metric that records per-batch hit counts and reports accuracy in %."""

    def process(self, data_batch, data_samples):
        """Store the number of correct predictions for one batch.

        Args:
            data_batch: Unused here.
            data_samples: A ``(score, gt)`` pair; the prediction is the
                argmax of ``score`` along dimension 1.
        """
        scores, targets = data_samples
        predictions = scores.argmax(dim=1)
        hits = (predictions == targets).sum().cpu()
        self.results.append({
            'batch_size': len(targets),
            'correct': hits,
        })

    def compute_metrics(self, results):
        """Aggregate the stored batch records into a single accuracy value.

        Args:
            results: Iterable of the dicts appended by :meth:`process`.

        Returns:
            dict: ``{'accuracy': 100 * total_correct / total_size}``.
        """
        n_correct = 0
        n_total = 0
        for record in results:
            # Plain ``+`` (not ``+=``) so no stored value is modified in place.
            n_correct = n_correct + record['correct']
            n_total = n_total + record['batch_size']
        return dict(accuracy=100 * n_correct / n_total)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the distributed training example.

    Options:
        --launcher: Job launcher, one of ``'none'``, ``'pytorch'``,
            ``'slurm'`` or ``'mpi'``. Defaults to ``'none'``.
        --local_rank / --local-rank: Integer local rank, defaults to ``0``.
            Both spellings are stored in ``args.local_rank``.

    Returns:
        argparse.Namespace: Namespace with ``launcher`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # ``torch.distributed.launch`` passes ``--local_rank`` before PyTorch 2.0
    # and ``--local-rank`` from 2.0 on. Accept both so the script does not
    # abort with "unrecognized arguments" on newer versions. The destination
    # is taken from the first long option, so it stays ``local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Train ``MMResNet50`` on CIFAR-10 with the launcher given on the CLI.

    Builds the training and validation ``DataLoader`` objects (the dataset is
    downloaded to ``data/cifar10`` when requested by ``download=True``),
    hands them to a ``Runner`` together with an SGD optimizer config and the
    ``Accuracy`` evaluator, and starts training.
    """
    args = parse_args()

    # Mean/std handed to ``transforms.Normalize`` in both pipelines.
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

    # NOTE(review): both loaders are prebuilt ``DataLoader`` objects without
    # an explicit distributed sampler. Confirm that ranks receive disjoint
    # shards when ``--launcher`` is not 'none'.
    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**norm_cfg)
    ])
    train_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=train_pipeline)
    train_dataloader = DataLoader(
        dataset=train_set, batch_size=32, shuffle=True)

    # Validation uses no augmentation, only tensor conversion + normalize.
    val_pipeline = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(**norm_cfg)])
    val_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=val_pipeline)
    val_dataloader = DataLoader(
        dataset=val_set, batch_size=32, shuffle=False)

    runner_kwargs = dict(
        model=MMResNet50(),
        work_dir='./work_dir',
        train_dataloader=train_dataloader,
        optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner = Runner(**runner_kwargs)
    runner.train()
|
||||
|
||||
|
||||
# Run the example only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
|
Loading…
Reference in New Issue