[Enhance] Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict`

pull/1531/head
Zhihao Lin 2024-04-12 14:25:54 +08:00 committed by GitHub
parent e258c84824
commit 39ed23fae8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 19 additions and 22 deletions
mmengine/_strategy

View File

@ -311,8 +311,8 @@ class DeepSpeedStrategy(BaseStrategy):
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
assert (exclude_frozen_parameters is None or
digit_version(deepspeed.__version__) >= digit_version('0.10.1')
), ('DeepSpeed >= 0.10.1 is required to enable '
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
), ('DeepSpeed >= 0.13.2 is required to enable '
'exclude_frozen_parameters')
self.exclude_frozen_parameters = exclude_frozen_parameters
@ -430,7 +430,7 @@ class DeepSpeedStrategy(BaseStrategy):
self.logger.info(f'Load checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
@ -468,7 +468,7 @@ class DeepSpeedStrategy(BaseStrategy):
self.logger.info(f'Resume checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
@ -551,6 +551,11 @@ class DeepSpeedStrategy(BaseStrategy):
level=logging.WARNING)
save_optimizer = True
state_dict_kwargs = {}
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
state_dict_kwargs[
'exclude_frozen_parameters'] = self.exclude_frozen_parameters
if save_optimizer:
if hasattr(self, 'optim_wrapper'):
# The key can not be 'optimizer', otherwise error will be
@ -558,27 +563,19 @@ class DeepSpeedStrategy(BaseStrategy):
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False,
exclude_frozen_parameters=self.exclude_frozen_parameters)
else:
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False)
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False,
**state_dict_kwargs)
else:
if self.model.zero_optimization_partition_weights():
# TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
# `exclude_frozen_parameters`.
state_dict = self.model._zero3_consolidated_16bit_state_dict()
state_dict = self.model._zero3_consolidated_16bit_state_dict(
**state_dict_kwargs)
else:
state_dict = self.model.module_state_dict(
exclude_frozen_parameters=self.exclude_frozen_parameters)
state_dict = self.model.module_state_dict(**state_dict_kwargs)
if is_main_process():
ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
save_checkpoint(ckpt, filename)