mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Register DeepSpeedStrategy even if deepspeed is not installed (#1240)
This commit is contained in:
parent
5bc841c09c
commit
86387da4a5
@ -1,15 +1,14 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmengine.utils import digit_version, is_installed
|
from mmengine.utils import digit_version
|
||||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||||
from .base import BaseStrategy
|
from .base import BaseStrategy
|
||||||
|
from .deepspeed import DeepSpeedStrategy
|
||||||
from .distributed import DDPStrategy
|
from .distributed import DDPStrategy
|
||||||
from .single_device import SingleDeviceStrategy
|
from .single_device import SingleDeviceStrategy
|
||||||
|
|
||||||
__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']
|
__all__ = [
|
||||||
|
'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy'
|
||||||
if is_installed('deepspeed'):
|
]
|
||||||
from .deepspeed import DeepSpeedStrategy # noqa: F401
|
|
||||||
__all__.append('DeepSpeedStrategy')
|
|
||||||
|
|
||||||
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
|
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
|
||||||
try:
|
try:
|
||||||
|
@ -4,7 +4,11 @@ import os.path as osp
|
|||||||
import time
|
import time
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import deepspeed
|
try:
|
||||||
|
import deepspeed
|
||||||
|
except ImportError:
|
||||||
|
deepspeed = None
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
@ -71,6 +75,10 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||||||
# the following args are for BaseStrategy
|
# the following args are for BaseStrategy
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
assert deepspeed is not None, \
|
||||||
|
'DeepSpeed is not installed. Please check ' \
|
||||||
|
'https://github.com/microsoft/DeepSpeed#installation.'
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.config = self._parse_config(config)
|
self.config = self._parse_config(config)
|
||||||
|
@ -2,11 +2,15 @@
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
|
||||||
|
|
||||||
from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper
|
from mmengine.optim.optimizer._deepspeed import DeepSpeedOptimWrapper
|
||||||
from mmengine.registry import MODEL_WRAPPERS
|
from mmengine.registry import MODEL_WRAPPERS
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||||
|
except ImportError:
|
||||||
|
DeepSpeedEngine = None
|
||||||
|
|
||||||
|
|
||||||
@MODEL_WRAPPERS.register_module()
|
@MODEL_WRAPPERS.register_module()
|
||||||
class MMDeepSpeedEngineWrapper:
|
class MMDeepSpeedEngineWrapper:
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmengine.utils import is_installed
|
|
||||||
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
||||||
AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper,
|
AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper,
|
||||||
DefaultOptimWrapperConstructor, OptimWrapper,
|
DeepSpeedOptimWrapper, DefaultOptimWrapperConstructor,
|
||||||
OptimWrapperDict, ZeroRedundancyOptimizer,
|
OptimWrapper, OptimWrapperDict,
|
||||||
build_optim_wrapper)
|
ZeroRedundancyOptimizer, build_optim_wrapper)
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
||||||
CosineAnnealingLR, CosineAnnealingMomentum,
|
CosineAnnealingLR, CosineAnnealingMomentum,
|
||||||
@ -32,9 +31,5 @@ __all__ = [
|
|||||||
'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
|
'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
|
||||||
'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
|
'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
|
||||||
'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler',
|
'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler',
|
||||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
|
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_installed('deepspeed'):
|
|
||||||
from .optimizer import DeepSpeedOptimWrapper # noqa:F401
|
|
||||||
__all__.append('DeepSpeedOptimWrapper')
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmengine.utils import is_installed
|
from ._deepspeed import DeepSpeedOptimWrapper
|
||||||
from .amp_optimizer_wrapper import AmpOptimWrapper
|
from .amp_optimizer_wrapper import AmpOptimWrapper
|
||||||
from .apex_optimizer_wrapper import ApexOptimWrapper
|
from .apex_optimizer_wrapper import ApexOptimWrapper
|
||||||
from .base import BaseOptimWrapper
|
from .base import BaseOptimWrapper
|
||||||
@ -14,9 +14,5 @@ __all__ = [
|
|||||||
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
|
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
|
||||||
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
|
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
|
||||||
'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
|
'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
|
||||||
'ZeroRedundancyOptimizer', 'BaseOptimWrapper'
|
'ZeroRedundancyOptimizer', 'BaseOptimWrapper', 'DeepSpeedOptimWrapper'
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_installed('deepspeed'):
|
|
||||||
from ._deepspeed import DeepSpeedOptimWrapper # noqa:F401
|
|
||||||
__all__.append('DeepSpeedOptimWrapper')
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user