mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Support exclude_frozen_parameters for DeepSpeedStrategy's resume (#1424)
This commit is contained in:
parent
46784185cf
commit
26f22ed283
@ -463,8 +463,15 @@ class DeepSpeedStrategy(BaseStrategy):
|
||||
self.logger.info(f'Resume checkpoint from {filename}')
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname, tag=basename, load_optimizer_states=resume_optimizer)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
load_optimizer_states=resume_optimizer,
|
||||
load_module_strict=not self.exclude_frozen_parameters)
|
||||
else:
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname, tag=basename, load_optimizer_states=resume_optimizer)
|
||||
|
||||
if resume_optimizer:
|
||||
self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user