[Enhancement] Register DeepSpeedStrategy even if deepspeed is not installed (#1240)

pull/1259/head
Qingyun 2023-07-18 11:02:55 +08:00 committed by GitHub
parent 5bc841c09c
commit 86387da4a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 23 deletions

View File

@@ -1,15 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import digit_version, is_installed
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .base import BaseStrategy
from .deepspeed import DeepSpeedStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy
__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']
if is_installed('deepspeed'):
from .deepspeed import DeepSpeedStrategy # noqa: F401
__all__.append('DeepSpeedStrategy')
__all__ = [
'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy'
]
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
try:

View File

@@ -4,7 +4,11 @@ import os.path as osp
import time
from typing import Callable, Dict, List, Optional, Union
import deepspeed
try:
import deepspeed
except ImportError:
deepspeed = None
import torch.nn as nn
import mmengine
@@ -71,6 +75,10 @@ class DeepSpeedStrategy(BaseStrategy):
# the following args are for BaseStrategy
**kwargs,
):
assert deepspeed is not None, \
'DeepSpeed is not installed. Please check ' \
'https://github.com/microsoft/DeepSpeed#installation.'
super().__init__(**kwargs)
self.config = self._parse_config(config)

View File

@@ -2,11 +2,15 @@
from typing import Any, Dict, List, Optional, Union
import torch
from deepspeed.runtime.engine import DeepSpeedEngine
from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper
from mmengine.registry import MODEL_WRAPPERS
try:
from deepspeed.runtime.engine import DeepSpeedEngine
except ImportError:
DeepSpeedEngine = None
@MODEL_WRAPPERS.register_module()
class MMDeepSpeedEngineWrapper:

View File

@@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper,
DefaultOptimWrapperConstructor, OptimWrapper,
OptimWrapperDict, ZeroRedundancyOptimizer,
build_optim_wrapper)
DeepSpeedOptimWrapper, DefaultOptimWrapperConstructor,
OptimWrapper, OptimWrapperDict,
ZeroRedundancyOptimizer, build_optim_wrapper)
# yapf: disable
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
CosineAnnealingLR, CosineAnnealingMomentum,
@@ -32,9 +31,5 @@ __all__ = [
'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler',
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
]
if is_installed('deepspeed'):
from .optimizer import DeepSpeedOptimWrapper # noqa:F401
__all__.append('DeepSpeedOptimWrapper')

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from ._deepspeed import DeepSpeedOptimWrapper
from .amp_optimizer_wrapper import AmpOptimWrapper
from .apex_optimizer_wrapper import ApexOptimWrapper
from .base import BaseOptimWrapper
@@ -14,9 +14,5 @@ __all__ = [
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
]
if is_installed('deepspeed'):
from ._deepspeed import DeepSpeedOptimWrapper # noqa:F401
__all__.append('DeepSpeedOptimWrapper')