[Enhancement] Register DeepSpeedStrategy even if deepspeed is not installed (#1240)
parent
5bc841c09c
commit
86387da4a5
|
@ -1,15 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import digit_version, is_installed
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
from .base import BaseStrategy
|
||||
from .deepspeed import DeepSpeedStrategy
|
||||
from .distributed import DDPStrategy
|
||||
from .single_device import SingleDeviceStrategy
|
||||
|
||||
__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']
|
||||
|
||||
if is_installed('deepspeed'):
|
||||
from .deepspeed import DeepSpeedStrategy # noqa: F401
|
||||
__all__.append('DeepSpeedStrategy')
|
||||
__all__ = [
|
||||
'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy'
|
||||
]
|
||||
|
||||
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
|
||||
try:
|
||||
|
|
|
@ -4,7 +4,11 @@ import os.path as osp
|
|||
import time
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import deepspeed
|
||||
try:
|
||||
import deepspeed
|
||||
except ImportError:
|
||||
deepspeed = None
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
import mmengine
|
||||
|
@ -71,6 +75,10 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
# the following args are for BaseStrategy
|
||||
**kwargs,
|
||||
):
|
||||
assert deepspeed is not None, \
|
||||
'DeepSpeed is not installed. Please check ' \
|
||||
'https://github.com/microsoft/DeepSpeed#installation.'
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = self._parse_config(config)
|
||||
|
|
|
@ -2,11 +2,15 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||
|
||||
from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
|
||||
try:
|
||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||
except ImportError:
|
||||
DeepSpeedEngine = None
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class MMDeepSpeedEngineWrapper:
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import is_installed
|
||||
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
||||
AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper,
|
||||
DefaultOptimWrapperConstructor, OptimWrapper,
|
||||
OptimWrapperDict, ZeroRedundancyOptimizer,
|
||||
build_optim_wrapper)
|
||||
DeepSpeedOptimWrapper, DefaultOptimWrapperConstructor,
|
||||
OptimWrapper, OptimWrapperDict,
|
||||
ZeroRedundancyOptimizer, build_optim_wrapper)
|
||||
# yapf: disable
|
||||
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
||||
CosineAnnealingLR, CosineAnnealingMomentum,
|
||||
|
@ -32,9 +31,5 @@ __all__ = [
|
|||
'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
|
||||
'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
|
||||
'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler',
|
||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
|
||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
|
||||
]
|
||||
|
||||
if is_installed('deepspeed'):
|
||||
from .optimizer import DeepSpeedOptimWrapper # noqa:F401
|
||||
__all__.append('DeepSpeedOptimWrapper')
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import is_installed
|
||||
from ._deepspeed import DeepSpeedOptimWrapper
|
||||
from .amp_optimizer_wrapper import AmpOptimWrapper
|
||||
from .apex_optimizer_wrapper import ApexOptimWrapper
|
||||
from .base import BaseOptimWrapper
|
||||
|
@ -14,9 +14,5 @@ __all__ = [
|
|||
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
|
||||
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
|
||||
'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
|
||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
|
||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
|
||||
]
|
||||
|
||||
if is_installed('deepspeed'):
|
||||
from ._deepspeed import DeepSpeedOptimWrapper # noqa:F401
|
||||
__all__.append('DeepSpeedOptimWrapper')
|
||||
|
|
Loading…
Reference in New Issue