[Feature] Support save_optimizer=False for DeepSpeed (#1474)

Zhihao Lin authored on 2024-01-24 11:12:54 +08:00, committed by GitHub
parent 396cac19cd
commit cd298e3086

mmengine/_strategy/deepspeed.py

@@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 
+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None
 
+import logging
+
 import torch.nn as nn
 
 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 
 from .base import BaseStrategy
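The new imports back a weights-only save path: `print_log` emits the ZeRO stage 3 warning in the final hunk, while `is_main_process`, `weights_to_cpu`, and `save_checkpoint` implement a plain rank-0 checkpoint write. A minimal sketch of that pattern, assuming `engine` is an already initialized `deepspeed.DeepSpeedEngine` (the variable name is illustrative):

    from mmengine.dist import is_main_process
    from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu

    # Gather only the wrapped module's weights; no optimizer or ZeRO
    # partition state is collected.
    state_dict = engine.module_state_dict()

    # Only rank 0 performs the file I/O; the other ranks skip the write.
    if is_main_process():
        save_checkpoint({'state_dict': weights_to_cpu(state_dict)}, 'model.pth')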
@@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
         """Save checkpoint to given ``filename``.
 
         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.
 
         Args:
             filename (str): Filename to save checkpoint.
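Since `save_optimizer` is now supported, a hedged usage sketch (assumes `strategy` is a fully prepared `DeepSpeedStrategy`; the path is illustrative):

    # Weights-only checkpoint: optimizer and ZeRO partition state are
    # skipped when the DeepSpeed config permits consolidated 16-bit
    # weights (see the warning added in the next hunk).
    strategy.save_checkpoint(
        'work_dirs/model_only.pth',
        save_optimizer=False,
        save_param_scheduler=True)  # scheduler state is still recorded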
@@ -527,25 +532,53 @@ class DeepSpeedStrategy(BaseStrategy):
             mmengine=mmengine.__version__ + get_git_hash(),
         )
 
-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
 
         dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
+
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
         else:
-            self.model.save_checkpoint(
-                dirname, tag=basename, client_state=extra_ckpt,
-                save_latest=False)
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
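Under ZeRO stage 3, the warning above asks users to enable 16-bit weight gathering before `save_optimizer=False` can take effect. A sketch of a matching strategy config for mmengine's `FlexibleRunner` (the surrounding keys are illustrative; in DeepSpeed's JSON schema the flag is spelled `stage3_gather_16bit_weights_on_model_save`):

    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(enabled=True),
        zero_optimization=dict(
            stage=3,
            # JSON spelling of `gather_16bit_weights_on_model_save`; lets
            # DeepSpeed consolidate the full 16-bit weights at save time.
            stage3_gather_16bit_weights_on_model_save=True))

With the flag enabled, the `else` branch above consolidates the partitioned weights via `_zero3_consolidated_16bit_state_dict()`, and rank 0 writes a regular mmengine checkpoint with `save_checkpoint`.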