[Fix] Fix placement policy in ColossalAIStrategy (#1440)

fanqiNO1 committed 2023-12-23 16:24:39 +08:00 (committed via GitHub)
parent efcd364124
commit 671f3bcdf4

@@ -120,8 +120,9 @@ class ColossalAIOptimWrapper(OptimWrapper):
         self.optimizer.backward(loss, **kwargs)
 
 
-@MODEL_WRAPPERS.register_module()
-class CollosalAIModelWrapper:
+@MODEL_WRAPPERS.register_module(
+    name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
+class ColossalAIModelWrapper:
 
     def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
         self.model_wrapper = model_wrapper
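
With the wrapper now registered under both the corrected name and the legacy misspelling, configs that still reference 'CollosalAIModelWrapper' keep resolving to the same class. A minimal sketch of how such a name alias behaves, using a throwaway mmengine Registry (the registry name and dummy __init__ here are illustrative, not part of this commit):

    from mmengine.registry import Registry

    MODEL_WRAPPERS = Registry('model_wrapper')  # stand-in registry for the sketch

    @MODEL_WRAPPERS.register_module(
        name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
    class ColossalAIModelWrapper:

        def __init__(self, model):
            self.model = model

    # Both spellings resolve to the same class, so existing configs still build.
    assert MODEL_WRAPPERS.get('CollosalAIModelWrapper') is ColossalAIModelWrapper
    assert MODEL_WRAPPERS.get('ColossalAIModelWrapper') is ColossalAIModelWrapper
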
@@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
     OPTIMIZER_DIR = 'optimizer'  # directory to save optimizer state.
     MODEL_DIR = 'model'  # directory to save model
     SCHEDULER_DIR = 'scheduler'  # directory to save scheduelrs
-    model: CollosalAIModelWrapper  # type: ignore
+    model: ColossalAIModelWrapper  # type: ignore
     optim_wrapper: ColossalAIOptimWrapper  # type: ignore
 
     def __init__(
@@ -468,8 +469,14 @@ class ColossalAIStrategy(BaseStrategy):
     def _build_plugin(self, plugin: Union[str, dict]):
         if isinstance(plugin, str):
             if plugin == 'gemini':
-                plugin = colo_plugin.GeminiPlugin(
-                    precision='bf16', placement_policy='cuda')
+                try:
+                    plugin = colo_plugin.GeminiPlugin(
+                        precision='bf16', placement_policy='auto')
+                except AssertionError:
+                    from colossalai.zero.gemini.placement_policy import \
+                        PlacementPolicyFactory as colo_placement
+                    raise ValueError('placement policy must be one of ' +
+                                     f'{list(colo_placement.policies.keys())}')
             elif plugin == 'lowlevel-zero':
                 plugin = colo_plugin.LowLevelZeroPlugin()
             else:
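
Context for the try/except above: recent ColossalAI releases dropped the 'cuda' placement policy, so GeminiPlugin asserts if it is passed; 'auto' instead lets Gemini move parameters between device and host memory at runtime. A minimal sketch of the call this branch now makes (ColossalAI import path taken from current releases, shown for illustration only):

    from colossalai.booster.plugin import GeminiPlugin

    # 'auto' and 'static' are the policies exposed by PlacementPolicyFactory in
    # recent ColossalAI versions; the removed 'cuda' value triggers the
    # AssertionError that the new except branch converts into a ValueError.
    plugin = GeminiPlugin(precision='bf16', placement_policy='auto')
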
@@ -508,11 +515,11 @@ class ColossalAIStrategy(BaseStrategy):
         self,
         model: nn.Module,
         optim_wrapper: Optional[OptimWrapper] = None,
-    ) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
-               CollosalAIModelWrapper]:  # type: ignore
+    ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
+               ColossalAIModelWrapper]:  # type: ignore
         """Wrap model with :class:`ModelWrapper`."""
         if self.model_wrapper is None:
-            self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
+            self.model_wrapper = {'type': 'ColossalAIModelWrapper'}
 
         # For zero series parallel, move `data_preprocessor` to current device
         # is reasonable. We need to `BaseDataPreprocessor.to` manually since
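
With the default wrapper spec corrected, a user config only needs the plugin shortcut, and either wrapper spelling can still be set explicitly thanks to the alias registered above. A hypothetical strategy config fragment (field names follow ColossalAIStrategy's constructor; the surrounding runner setup is omitted):

    # Hypothetical config fragment for a runner using ColossalAIStrategy.
    strategy = dict(
        type='ColossalAIStrategy',
        plugin='gemini',  # now built with placement_policy='auto'
        model_wrapper=dict(type='ColossalAIModelWrapper'),
    )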