diff --git a/mmengine/_strategy/__init__.py b/mmengine/_strategy/__init__.py index 0b201aef..af369751 100644 --- a/mmengine/_strategy/__init__.py +++ b/mmengine/_strategy/__init__.py @@ -1,15 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import digit_version, is_installed +from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION from .base import BaseStrategy +from .deepspeed import DeepSpeedStrategy from .distributed import DDPStrategy from .single_device import SingleDeviceStrategy -__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy'] - -if is_installed('deepspeed'): - from .deepspeed import DeepSpeedStrategy # noqa: F401 - __all__.append('DeepSpeedStrategy') +__all__ = [ + 'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy' +] if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): try: diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 0c1d1e42..999ba4d0 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -4,7 +4,11 @@ import os.path as osp import time from typing import Callable, Dict, List, Optional, Union -import deepspeed +try: + import deepspeed +except ImportError: + deepspeed = None + import torch.nn as nn import mmengine @@ -71,6 +75,10 @@ class DeepSpeedStrategy(BaseStrategy): # the following args are for BaseStrategy **kwargs, ): + assert deepspeed is not None, \ + 'DeepSpeed is not installed. Please check ' \ + 'https://github.com/microsoft/DeepSpeed#installation.' + super().__init__(**kwargs) self.config = self._parse_config(config) diff --git a/mmengine/model/wrappers/_deepspeed.py b/mmengine/model/wrappers/_deepspeed.py index a161afd3..51559d8b 100644 --- a/mmengine/model/wrappers/_deepspeed.py +++ b/mmengine/model/wrappers/_deepspeed.py @@ -2,11 +2,15 @@ from typing import Any, Dict, List, Optional, Union import torch -from deepspeed.runtime.engine import DeepSpeedEngine from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper from mmengine.registry import MODEL_WRAPPERS +try: + from deepspeed.runtime.engine import DeepSpeedEngine +except ImportError: + DeepSpeedEngine = None + @MODEL_WRAPPERS.register_module() class MMDeepSpeedEngineWrapper: diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py index d83fa51c..78c6850a 100644 --- a/mmengine/optim/__init__.py +++ b/mmengine/optim/__init__.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import is_installed from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper, - DefaultOptimWrapperConstructor, OptimWrapper, - OptimWrapperDict, ZeroRedundancyOptimizer, - build_optim_wrapper) + DeepSpeedOptimWrapper, DefaultOptimWrapperConstructor, + OptimWrapper, OptimWrapperDict, + ZeroRedundancyOptimizer, build_optim_wrapper) # yapf: disable from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, CosineAnnealingLR, CosineAnnealingMomentum, @@ -32,9 +31,5 @@ __all__ = [ 'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler', - 'ZeroRedundancyOptimizer', 'BaseOptimWrapper' + 'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper' ] - -if is_installed('deepspeed'): - from .optimizer import DeepSpeedOptimWrapper # noqa:F401 - __all__.append('DeepSpeedOptimWrapper') diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py index b9674ee0..bc868e98 100644 --- a/mmengine/optim/optimizer/__init__.py +++ b/mmengine/optim/optimizer/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import is_installed +from ._deepspeed import DeepSpeedOptimWrapper from .amp_optimizer_wrapper import AmpOptimWrapper from .apex_optimizer_wrapper import ApexOptimWrapper from .base import BaseOptimWrapper @@ -14,9 +14,5 @@ __all__ = [ 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict', - 'ZeroRedundancyOptimizer', 'BaseOptimWrapper' + 'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper' ] - -if is_installed('deepspeed'): - from ._deepspeed import DeepSpeedOptimWrapper # noqa:F401 - __all__.append('DeepSpeedOptimWrapper')