[Fix] Fix placement policy in ColossalAIStrategy (#1440)
parent efcd364124
commit 671f3bcdf4

mmengine/_strategy
@@ -120,8 +120,9 @@ class ColossalAIOptimWrapper(OptimWrapper):
         self.optimizer.backward(loss, **kwargs)


-@MODEL_WRAPPERS.register_module()
-class CollosalAIModelWrapper:
+@MODEL_WRAPPERS.register_module(
+    name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
+class ColossalAIModelWrapper:

     def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
         self.model_wrapper = model_wrapper
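Note: registering the class under both the corrected and the legacy misspelled name keeps existing configs loadable. A minimal sketch of the effect (not part of the commit), assuming ColossalAI is installed so the module defining the wrapper is importable:

import mmengine._strategy.colossalai  # noqa: F401  (assumed module path; importing triggers registration)
from mmengine.registry import MODEL_WRAPPERS

# Both spellings resolve to the same class, so a config that still says
# 'CollosalAIModelWrapper' keeps working after the rename.
assert (MODEL_WRAPPERS.get('CollosalAIModelWrapper')
        is MODEL_WRAPPERS.get('ColossalAIModelWrapper'))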
@@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
     OPTIMIZER_DIR = 'optimizer'  # directory to save optimizer state.
     MODEL_DIR = 'model'  # directory to save model
     SCHEDULER_DIR = 'scheduler'  # directory to save schedulers
-    model: CollosalAIModelWrapper  # type: ignore
+    model: ColossalAIModelWrapper  # type: ignore
     optim_wrapper: ColossalAIOptimWrapper  # type: ignore

     def __init__(
@@ -468,8 +469,14 @@ class ColossalAIStrategy(BaseStrategy):
     def _build_plugin(self, plugin: Union[str, dict]):
         if isinstance(plugin, str):
             if plugin == 'gemini':
-                plugin = colo_plugin.GeminiPlugin(
-                    precision='bf16', placement_policy='cuda')
+                try:
+                    plugin = colo_plugin.GeminiPlugin(
+                        precision='bf16', placement_policy='auto')
+                except AssertionError:
+                    from colossalai.zero.gemini.placement_policy import \
+                        PlacementPolicyFactory as colo_placement
+                    raise ValueError('placement policy must be one of ' +
+                                     f'{list(colo_placement.policies.keys())}')
             elif plugin == 'lowlevel-zero':
                 plugin = colo_plugin.LowLevelZeroPlugin()
             else:
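Note: the default Gemini placement policy changes from 'cuda' to 'auto', and an unsupported policy is re-raised as a ValueError listing the valid names rather than a bare AssertionError from inside ColossalAI. A short sketch of the resulting call, assuming colo_plugin refers to colossalai.booster.plugin:

from colossalai.booster.plugin import GeminiPlugin

# 'auto' lets Gemini move parameter chunks between CPU and GPU based on
# runtime memory pressure, instead of pinning everything to 'cuda'.
plugin = GeminiPlugin(precision='bf16', placement_policy='auto')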
@@ -508,11 +515,11 @@ class ColossalAIStrategy(BaseStrategy):
         self,
         model: nn.Module,
         optim_wrapper: Optional[OptimWrapper] = None,
-    ) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
-               CollosalAIModelWrapper]:  # type: ignore
+    ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
+               ColossalAIModelWrapper]:  # type: ignore
         """Wrap model with :class:`ModelWrapper`."""
         if self.model_wrapper is None:
-            self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
+            self.model_wrapper = {'type': 'ColossalAIModelWrapper'}

         # For zero series parallel, move `data_preprocessor` to current device
         # is reasonable. We need to `BaseDataPreprocessor.to` manually since
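Note: with the dual registration above, a user config may reference either spelling of the wrapper. A hypothetical strategy config (field layout is illustrative; only the registered type strings and the 'gemini' plugin name come from this file):

strategy = dict(
    type='ColossalAIStrategy',
    plugin='gemini',  # now built as GeminiPlugin(precision='bf16', placement_policy='auto')
    model_wrapper=dict(type='ColossalAIModelWrapper'),
    # dict(type='CollosalAIModelWrapper') would still resolve to the same class
)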