# Copyright (c) OpenMMLab. All rights reserved.
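"""Train a ResNet-50 classifier on CIFAR-10 with MMEngine's FlexibleRunner.

The script can be launched with plain ``python`` or with ``torchrun``, and
optionally switches the training strategy to FSDP (``--use-fsdp``) or
DeepSpeed (``--use-deepspeed``); see the example commands at the bottom of
the file.
"""
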
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner._flexible_runner import FlexibleRunner


class MMResNet50(BaseModel):
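    """ResNet-50 wrapped as an MMEngine ``BaseModel``.

    ``forward`` returns a loss dict in ``'loss'`` mode and
    ``(logits, labels)`` in ``'predict'`` mode for the evaluator.
    """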

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):
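    """Top-1 accuracy metric.

    ``process`` stores the number of correct predictions and the batch size
    for each batch; ``compute_metrics`` aggregates them into a percentage.
    """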

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
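    """Parse launcher arguments.

    ``--local_rank`` / ``--local-rank`` is accepted so the script also works
    with distributed launchers that pass the local rank as an argument.
    """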
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument('--use-fsdp', action='store_true')
    parser.add_argument('--use-deepspeed', action='store_true')
    args = parser.parse_args()
    return args


def main():
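    """Build the CIFAR-10 pipeline, choose a strategy, and start training."""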
    args = parse_args()
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
    train_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ]))
    valid_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(**norm_cfg)]))
    train_dataloader = dict(
        batch_size=128,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))
    val_dataloader = dict(
        batch_size=128,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))

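    # Choose the training strategy and a matching optimizer wrapper based on
    # the CLI flags.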
    if args.use_deepspeed:
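        # DeepSpeed: ZeRO stage-3 partitioning with fp16 mixed precision;
        # inputs_to_half=[0] casts the first positional input (the images)
        # to half precision.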
        strategy = dict(
            type='DeepSpeedStrategy',
            fp16=dict(
                enabled=True,
                fp16_master_weights_and_grads=False,
                loss_scale=0,
                loss_scale_window=500,
                hysteresis=2,
                min_loss_scale=1,
                initial_scale_power=15,
            ),
            inputs_to_half=[0],
            zero_optimization=dict(
                stage=3,
                allgather_partitions=True,
                reduce_scatter=True,
                allgather_bucket_size=50000000,
                reduce_bucket_size=50000000,
                overlap_comm=True,
                contiguous_gradients=True,
                cpu_offload=False),
        )
        optim_wrapper = dict(
            type='DeepSpeedOptimWrapper',
            optimizer=dict(type='AdamW', lr=1e-3))
    elif args.use_fsdp:
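        # FSDP: automatically wrap submodules once they hold at least
        # ~10M parameters (min_num_params=1e7).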
        from functools import partial

        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
        size_based_auto_wrap_policy = partial(
            size_based_auto_wrap_policy, min_num_params=1e7)
        strategy = dict(
            type='FSDPStrategy',
            model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy))
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))
    else:
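        # No explicit strategy: fall back to FlexibleRunner's default and
        # train with an AMP-wrapped AdamW optimizer.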
        strategy = None
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))

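    # Dataloaders, optimizer wrapper, scheduler and evaluator are passed as
    # configs; FlexibleRunner builds them and applies the chosen strategy.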
    runner = FlexibleRunner(
        model=MMResNet50(),
        work_dir='./work_dirs',
        strategy=strategy,
        train_dataloader=train_dataloader,
        optim_wrapper=optim_wrapper,
        param_scheduler=dict(type='LinearLR'),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy))
    runner.train()


if __name__ == '__main__':
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp  # noqa: E501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed  # noqa: E501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py
    # python distributed_training_with_flexible_runner.py
    main()