Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Feature] Support save_optimizer=False for DeepSpeed (#1474)
parent 396cac19cd
commit cd298e3086
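
Before the diff, a minimal usage sketch (not part of the commit) of what this feature enables. The construction values are illustrative assumptions, and `stage3_gather_16bit_weights_on_model_save` is DeepSpeed's documented config key for the `gather_16bit_weights_on_model_save` behavior the new code checks:

# Hypothetical sketch: a DeepSpeed strategy set up so that
# `save_checkpoint(..., save_optimizer=False)` can take effect under
# ZeRO stage 3. All values below are illustrative.
from mmengine._strategy import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    zero_optimization=dict(
        stage=3,
        # Lets DeepSpeed consolidate the full 16-bit weights at save
        # time; without this, the commit's new check warns and falls
        # back to saving the optimizer state as well.
        stage3_gather_16bit_weights_on_model_save=True))

# ... after the usual strategy.prepare(...) call has wrapped the model
# in a DeepSpeed engine ...
strategy.save_checkpoint('work_dir/weights_only.pth', save_optimizer=False)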
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 
+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None
 
+import logging
+
 import torch.nn as nn
 
 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 from .base import BaseStrategy
 
@@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
         """Save checkpoint to given ``filename``.
 
         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.
 
         Args:
             filename (str): Filename to save checkpoint.
@@ -527,25 +532,53 @@ class DeepSpeedStrategy(BaseStrategy):
             mmengine=mmengine.__version__ + get_git_hash(),
         )
 
-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
 
-        dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            dirname, basename = osp.split(filename)
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
         else:
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False)
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
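
Worth noting about the new `else` branch: the consolidated weights are written on the main process with mmengine's regular `save_checkpoint`, so a `save_optimizer=False` checkpoint is a single torch-loadable file rather than a DeepSpeed checkpoint directory. A minimal loading sketch, with the path and `model` as placeholders:

import torch

ckpt = torch.load('work_dir/weights_only.pth', map_location='cpu')
# Expected keys: 'state_dict' plus the extra metadata assembled above,
# e.g. 'meta' and (when saved) 'param_schedulers'.
print(sorted(ckpt))
model.load_state_dict(ckpt['state_dict'])  # plain nn.Module, no DeepSpeed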