[Feature] Support save_optimizer=False for DeepSpeed (#1474)

Zhihao Lin, 2024-01-24 11:12:54 +08:00 (committed by GitHub)
parent 396cac19cd
commit cd298e3086
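Before this change, `DeepSpeedStrategy.save_checkpoint` always delegated to DeepSpeed's engine-level `save_checkpoint`, which writes its optimizer state files alongside the weights, so `save_optimizer` was documented as unsupported. With this commit, passing `save_optimizer=False` makes the strategy collect a full model state dict and write an ordinary weights-only mmengine checkpoint from the main process. A minimal usage sketch (the constructor arguments and path are illustrative placeholders, not taken from this commit):

```python
# Hypothetical usage sketch of the new behavior; `DeepSpeedStrategy`
# is real, but the config values and path below are placeholders.
from mmengine._strategy import DeepSpeedStrategy

strategy = DeepSpeedStrategy(zero_optimization=dict(stage=2))
# ... prepare the model/optimizer with the strategy and train ...

# New in this commit: skip the optimizer state and write a plain,
# weights-only checkpoint instead of a DeepSpeed checkpoint directory.
strategy.save_checkpoint('work_dir/weights_only.pth', save_optimizer=False)
```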

mmengine/_strategy/deepspeed.py

```diff
@@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 
+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None
 
+import logging
+
 import torch.nn as nn
 
 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 
 from .base import BaseStrategy
```
```diff
@@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
         """Save checkpoint to given ``filename``.
 
         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.
 
         Args:
             filename (str): Filename to save checkpoint.
```
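After the docstring update, only `callback` remains unsupported. As the final hunk below shows, a checkpoint saved with `save_optimizer=False` is a single torch-serialized file rather than a DeepSpeed checkpoint directory, so it can be inspected without DeepSpeed installed. A sketch (the path and exact keys are illustrative):

```python
import torch

# Weights-only checkpoints from the new code path are plain dicts,
# e.g. {'state_dict': <CPU param tensors>, 'meta': {...}, ...}.
ckpt = torch.load('work_dir/weights_only.pth', map_location='cpu')
print(sorted(ckpt.keys()))
print(len(ckpt['state_dict']))  # number of saved parameter tensors
```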
```diff
@@ -527,25 +532,53 @@ class DeepSpeedStrategy(BaseStrategy):
             mmengine=mmengine.__version__ + get_git_hash(),
         )
 
-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
 
-        dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            dirname, basename = osp.split(filename)
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
         else:
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False)
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
```
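The new `else` branch splits on the ZeRO stage. `zero_optimization_partition_weights()` is true only under ZeRO stage 3, where each rank holds just a shard of the weights, so a consolidated 16-bit state dict has to be gathered with `_zero3_consolidated_16bit_state_dict()`; that gather only works when `gather_16bit_weights_on_model_save` is enabled, which is why the warning path above falls back to `save_optimizer=True` otherwise. For stages 0-2, `module_state_dict()` already sees the full weights on every rank. The option the warning asks users to set might appear in a `FlexibleRunner` config like this (a sketch; apart from the strategy type and `gather_16bit_weights_on_model_save=True`, the values are illustrative):

```python
# Hypothetical FlexibleRunner strategy config enabling weights-only
# saving under ZeRO-3; `gather_16bit_weights_on_model_save` is the
# DeepSpeed option referenced by the warning, other values are
# placeholders.
strategy = dict(
    type='DeepSpeedStrategy',
    zero_optimization=dict(
        stage=3,
        # Gather every rank's 16-bit weight shards into one full state
        # dict at save time so `save_optimizer=False` can be honored.
        gather_16bit_weights_on_model_save=True,
    ),
)
```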