[Enhancement] Support deepspeed with flexible runner (#1673)

* [Feature] Support deepspeed with flexible runner

* [Fix] Reformat with yapf

* [Refactor] Rename configs

* [Fix] Reformat with yapf

* [Refactor] Remove unused keys

* [Refactor] Change the _base_ path

* [Refactor] Reformat
pull/1679/head
fanqiNO1 2023-06-29 10:16:27 +08:00 committed by GitHub
parent 68758db7a8
commit 658db80089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 2 deletions

View File

@ -0,0 +1,32 @@
# Inherit the full ViT-huge training recipe; this file only swaps in the
# DeepSpeed execution strategy.
_base_ = ['./vit-huge-p14_8xb128-coslr-50e_in1k.py']

# DeepSpeed-aware optimizer wrapper (lets the DeepSpeed engine drive the
# optimizer step and gradient scaling).
optim_wrapper = {'type': 'DeepSpeedOptimWrapper'}

# Training strategy: DeepSpeed with ZeRO stage-3 partitioning and fp16
# dynamic loss scaling.
strategy = {
    'type': 'DeepSpeedStrategy',
    'fp16': {
        'enabled': True,
        'fp16_master_weights_and_grads': False,
        'loss_scale': 0,
        'loss_scale_window': 500,
        'hysteresis': 2,
        'min_loss_scale': 1,
        'initial_scale_power': 15,
    },
    'inputs_to_half': ['inputs'],
    'zero_optimization': {
        'stage': 3,
        'allgather_partitions': True,
        'reduce_scatter': True,
        'allgather_bucket_size': 50000000,
        'reduce_bucket_size': 50000000,
        'overlap_comm': True,
        'contiguous_gradients': True,
        'cpu_offload': False,
    },
}

# FlexibleRunner is the runner type that knows how to consume `strategy`.
runner_type = 'FlexibleRunner'

View File

@ -0,0 +1,32 @@
# Inherit the full ViT-large training recipe; this file only swaps in the
# DeepSpeed execution strategy.
_base_ = ['./vit-large-p16_8xb128-coslr-50e_in1k.py']

# DeepSpeed-aware optimizer wrapper (lets the DeepSpeed engine drive the
# optimizer step and gradient scaling).
optim_wrapper = {'type': 'DeepSpeedOptimWrapper'}

# Training strategy: DeepSpeed with ZeRO stage-3 partitioning and fp16
# dynamic loss scaling.
strategy = {
    'type': 'DeepSpeedStrategy',
    'fp16': {
        'enabled': True,
        'fp16_master_weights_and_grads': False,
        'loss_scale': 0,
        'loss_scale_window': 500,
        'hysteresis': 2,
        'min_loss_scale': 1,
        'initial_scale_power': 15,
    },
    'inputs_to_half': ['inputs'],
    'zero_optimization': {
        'stage': 3,
        'allgather_partitions': True,
        'reduce_scatter': True,
        'allgather_bucket_size': 50000000,
        'reduce_bucket_size': 50000000,
        'overlap_comm': True,
        'contiguous_gradients': True,
        'cpu_offload': False,
    },
}

# FlexibleRunner is the runner type that knows how to consume `strategy`.
runner_type = 'FlexibleRunner'

View File

@ -7,6 +7,7 @@ from copy import deepcopy
import mmengine import mmengine
from mmengine.config import Config, ConfigDict, DictAction from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults from mmengine.evaluator import DumpResults
from mmengine.registry import RUNNERS
from mmengine.runner import Runner from mmengine.runner import Runner
@@ -169,7 +170,13 @@ def main():
     cfg = merge_args(cfg, args)

     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)

     if args.out and args.out_item in ['pred', None]:
         runner.test_evaluator.metrics.append(

View File

@@ -5,6 +5,7 @@ import os.path as osp
 from copy import deepcopy

 from mmengine.config import Config, ConfigDict, DictAction
+from mmengine.registry import RUNNERS
 from mmengine.runner import Runner
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -149,7 +150,13 @@ def main():
     cfg = merge_args(cfg, args)

     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)

     # start training
     runner.train()