[Enhance] Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` (#1517)
parent e258c84824
commit 39ed23fae8
mmengine/_strategy
|
@ -311,8 +311,8 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.config['steps_per_print'] = steps_per_print
|
||||
self._inputs_to_half = inputs_to_half
|
||||
assert (exclude_frozen_parameters is None or
|
||||
digit_version(deepspeed.__version__) >= digit_version('0.10.1')
|
||||
), ('DeepSpeed >= 0.10.1 is required to enable '
|
||||
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
|
||||
), ('DeepSpeed >= 0.13.2 is required to enable '
|
||||
'exclude_frozen_parameters')
|
||||
self.exclude_frozen_parameters = exclude_frozen_parameters
|
||||
|
||||
|
@ -430,7 +430,7 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.logger.info(f'Load checkpoint from {filename}')
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
|
@ -468,7 +468,7 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.logger.info(f'Resume checkpoint from {filename}')
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
|
@ -551,6 +551,11 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
level=logging.WARNING)
|
||||
save_optimizer = True
|
||||
|
||||
state_dict_kwargs = {}
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
state_dict_kwargs[
|
||||
'exclude_frozen_parameters'] = self.exclude_frozen_parameters
|
||||
|
||||
if save_optimizer:
|
||||
if hasattr(self, 'optim_wrapper'):
|
||||
# The key can not be 'optimizer', otherwise error will be
|
||||
|
@ -558,27 +563,19 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
self.model.save_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
client_state=extra_ckpt,
|
||||
save_latest=False,
|
||||
exclude_frozen_parameters=self.exclude_frozen_parameters)
|
||||
else:
|
||||
self.model.save_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
client_state=extra_ckpt,
|
||||
save_latest=False)
|
||||
self.model.save_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
client_state=extra_ckpt,
|
||||
save_latest=False,
|
||||
**state_dict_kwargs)
|
||||
else:
|
||||
if self.model.zero_optimization_partition_weights():
|
||||
# TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
|
||||
# `exclude_frozen_parameters`.
|
||||
state_dict = self.model._zero3_consolidated_16bit_state_dict()
|
||||
state_dict = self.model._zero3_consolidated_16bit_state_dict(
|
||||
**state_dict_kwargs)
|
||||
else:
|
||||
state_dict = self.model.module_state_dict(
|
||||
exclude_frozen_parameters=self.exclude_frozen_parameters)
|
||||
state_dict = self.model.module_state_dict(**state_dict_kwargs)
|
||||
|
||||
if is_main_process():
|
||||
ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
|
||||
save_checkpoint(ckpt, filename)
|
||||
|
|
Loading…
Reference in New Issue