[Fix] Fix placement policy in ColossalAIStrategy ()

fanqiNO1 2023-12-23 16:24:39 +08:00 committed by GitHub
parent efcd364124
commit 671f3bcdf4
1 changed file with 15 additions and 8 deletions
mmengine/_strategy


@@ -120,8 +120,9 @@ class ColossalAIOptimWrapper(OptimWrapper):
         self.optimizer.backward(loss, **kwargs)
 
 
-@MODEL_WRAPPERS.register_module()
-class CollosalAIModelWrapper:
+@MODEL_WRAPPERS.register_module(
+    name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
+class ColossalAIModelWrapper:
 
     def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
         self.model_wrapper = model_wrapper
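The rename fixes the misspelled `CollosalAIModelWrapper`; registering both names keeps existing configs working. A minimal sketch of what the dual registration buys (assuming colossalai is installed so the registration above actually runs on import):

    from mmengine.registry import MODEL_WRAPPERS
    import mmengine._strategy  # registers the wrapper (needs colossalai)

    # Both keys resolve to the same class, so a config written against
    # the old misspelled name still builds correctly after the rename.
    assert (MODEL_WRAPPERS.get('ColossalAIModelWrapper')
            is MODEL_WRAPPERS.get('CollosalAIModelWrapper'))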
@@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
     OPTIMIZER_DIR = 'optimizer'  # directory to save optimizer state.
     MODEL_DIR = 'model'  # directory to save model
     SCHEDULER_DIR = 'scheduler'  # directory to save schedulers
-    model: CollosalAIModelWrapper  # type: ignore
+    model: ColossalAIModelWrapper  # type: ignore
     optim_wrapper: ColossalAIOptimWrapper  # type: ignore
 
     def __init__(
@@ -468,8 +469,14 @@
     def _build_plugin(self, plugin: Union[str, dict]):
         if isinstance(plugin, str):
             if plugin == 'gemini':
-                plugin = colo_plugin.GeminiPlugin(
-                    precision='bf16', placement_policy='cuda')
+                try:
+                    plugin = colo_plugin.GeminiPlugin(
+                        precision='bf16', placement_policy='auto')
+                except AssertionError:
+                    from colossalai.zero.gemini.placement_policy import \
+                        PlacementPolicyFactory as colo_placement
+                    raise ValueError('placement policy must be one of ' +
+                                     f'{list(colo_placement.policies.keys())}')
             elif plugin == 'lowlevel-zero':
                 plugin = colo_plugin.LowLevelZeroPlugin()
             else:
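The 'gemini' shorthand previously pinned parameters with placement_policy='cuda', which recent ColossalAI releases reject with an AssertionError; 'auto' lets Gemini move parameters between CPU and GPU based on memory pressure, and an unsupported policy is now re-raised as a ValueError listing the supported names. A usage sketch, assuming the strategy constructor exposes the plugin argument that feeds _build_plugin:

    from mmengine._strategy import ColossalAIStrategy

    # The string shorthand now expands to
    # GeminiPlugin(precision='bf16', placement_policy='auto').
    strategy = ColossalAIStrategy(plugin='gemini')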
@@ -508,11 +515,11 @@
         self,
         model: nn.Module,
         optim_wrapper: Optional[OptimWrapper] = None,
-    ) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
-               CollosalAIModelWrapper]:  # type: ignore
+    ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
+               ColossalAIModelWrapper]:  # type: ignore
         """Wrap model with :class:`ModelWrapper`."""
         if self.model_wrapper is None:
-            self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
+            self.model_wrapper = {'type': 'ColossalAIModelWrapper'}
 
         # For zero series parallel, moving `data_preprocessor` to the current
         # device is reasonable. We need to call `BaseDataPreprocessor.to` manually since
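The comment above refers to moving `data_preprocessor` by hand. A minimal sketch of that idea, with `model` assumed to be an mmengine `BaseModel` exposing a `data_preprocessor` attribute:

    import torch
    from mmengine.model import BaseDataPreprocessor

    # ZeRO-style wrappers manage parameter placement themselves, so the
    # strategy moves the preprocessor to the local GPU explicitly.
    device = torch.device(f'cuda:{torch.cuda.current_device()}')
    if isinstance(getattr(model, 'data_preprocessor', None),
                  BaseDataPreprocessor):
        model.data_preprocessor.to(device)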