[Enhancement] Support deepspeed with flexible runner (#1673)

* [Feature] Support deepspeed with flexible runner

* [Fix] Reformat with yapf

* [Refacor] Rename configs

* [Fix] Reformat with yapf

* [Refactor] Remove unused keys

* [Refactor] Change the _base_ path

* [Refactor] Reformat
pull/1679/head
fanqiNO1 2023-06-29 10:16:27 +08:00 committed by GitHub
parent 68758db7a8
commit 658db80089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 2 deletions

View File

@ -0,0 +1,32 @@
_base_ = ['./vit-huge-p14_8xb128-coslr-50e_in1k.py']
# optimizer wrapper
optim_wrapper = dict(type='DeepSpeedOptimWrapper')
# training strategy
# Deepspeed with ZeRO3 + fp16
strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=500,
hysteresis=2,
min_loss_scale=1,
initial_scale_power=15,
),
inputs_to_half=['inputs'],
zero_optimization=dict(
stage=3,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False,
))
# runner which supports strategies
runner_type = 'FlexibleRunner'

View File

@ -0,0 +1,32 @@
_base_ = ['./vit-large-p16_8xb128-coslr-50e_in1k.py']
# optimizer wrapper
optim_wrapper = dict(type='DeepSpeedOptimWrapper')
# training strategy
# Deepspeed with ZeRO3 + fp16
strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=500,
hysteresis=2,
min_loss_scale=1,
initial_scale_power=15,
),
inputs_to_half=['inputs'],
zero_optimization=dict(
stage=3,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False,
))
# runner which supports strategies
runner_type = 'FlexibleRunner'

View File

@ -7,6 +7,7 @@ from copy import deepcopy
import mmengine
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
@ -169,7 +170,13 @@ def main():
cfg = merge_args(cfg, args)
# build the runner from config
runner = Runner.from_cfg(cfg)
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
if args.out and args.out_item in ['pred', None]:
runner.test_evaluator.metrics.append(

View File

@ -5,6 +5,7 @@ import os.path as osp
from copy import deepcopy
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@ -149,7 +150,13 @@ def main():
cfg = merge_args(cfg, args)
# build the runner from config
runner = Runner.from_cfg(cfg)
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start training
runner.train()