https://github.com/open-mmlab/mmengine.git
[Feature] Support save_optimizer=False for DeepSpeed (#1474)
commit cd298e3086
parent 396cac19cd
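
With this change, `save_optimizer=False` makes `DeepSpeedStrategy` write a plain, weights-only checkpoint instead of a full DeepSpeed checkpoint directory. The flag is normally passed down to the strategy from the runner's `CheckpointHook`. Below is a minimal, hypothetical usage sketch; the model, dataloader and hyperparameters are placeholders and are not part of this commit:

# Hypothetical sketch of driving the new flag from a FlexibleRunner config.
# Only `save_optimizer=False` relates to this commit; MyModel and
# train_dataloader are assumed to be defined elsewhere.
from mmengine.runner import FlexibleRunner

runner = FlexibleRunner(
    model=MyModel(),
    work_dir='./work_dirs/demo',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=10),
    optim_wrapper=dict(
        type='DeepSpeedOptimWrapper',
        optimizer=dict(type='AdamW', lr=1e-4)),
    strategy=dict(
        type='DeepSpeedStrategy',
        fp16=dict(enabled=True),
        # For ZeRO stage 3, the warning added below asks you to also enable
        # 16-bit weight gathering in the DeepSpeed config; stage 2 keeps
        # this sketch minimal.
        zero_optimization=dict(stage=2)),
    default_hooks=dict(
        # CheckpointHook forwards `save_optimizer` to the strategy, so the
        # saved checkpoint skips optimizer states.
        checkpoint=dict(type='CheckpointHook', interval=1,
                        save_optimizer=False)),
)
runner.train()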
@@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 
+from mmengine.logging import print_log
+
 try:
     import deepspeed
 except ImportError:
     deepspeed = None
 
+import logging
+
 import torch.nn as nn
 
 import mmengine
-from mmengine.dist import init_dist
+from mmengine.dist import init_dist, is_main_process
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
 from mmengine.utils import apply_to, digit_version, get_git_hash
 from .base import BaseStrategy
 
 
@@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
         """Save checkpoint to given ``filename``.
 
         Warning:
-            `save_optimizer` and `callback` parameters are not supported yet.
+            `callback` parameter is not supported yet.
 
         Args:
             filename (str): Filename to save checkpoint.
@@ -527,25 +532,53 @@ class DeepSpeedStrategy(BaseStrategy):
             mmengine=mmengine.__version__ + get_git_hash(),
         )
 
-        if save_optimizer and hasattr(self, 'optim_wrapper'):
-            # The key can not be 'optimizer', otherwise error will be thrown
-            # when loading or resuming checkpoint.
-            extra_ckpt['optim_wrapper'] = self.optim_state_dict()
-
         if save_param_scheduler and hasattr(self, 'param_schedulers'):
             extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
 
-        dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False,
-                exclude_frozen_parameters=self.exclude_frozen_parameters)
-        else:
-            self.model.save_checkpoint(
-                dirname,
-                tag=basename,
-                client_state=extra_ckpt,
-                save_latest=False)
+        if (not save_optimizer
+                and self.model.zero_optimization_partition_weights()
+                and not self.model.zero_gather_16bit_weights_on_model_save()):
+            print_log(
+                'Configured to `save_optimizer=False`, but currently using '
+                "DeepSpeed's ZeRO stage 3 with "
+                '`gather_16bit_weights_on_model_save=False`. In '
+                'this configuration, the model cannot be saved properly '
+                'and will be saved with the optimizer state. '
+                'To support `save_optimizer=False`, please set '
+                '`gather_16bit_weights_on_model_save=True` in your '
+                'DeepSpeed config.',
+                logger='current',
+                level=logging.WARNING)
+            save_optimizer = True
+
+        if save_optimizer:
+            if hasattr(self, 'optim_wrapper'):
+                # The key can not be 'optimizer', otherwise error will be
+                # thrown when loading or resuming checkpoint.
+                extra_ckpt['optim_wrapper'] = self.optim_state_dict()
+
+            dirname, basename = osp.split(filename)
+            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False,
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            else:
+                self.model.save_checkpoint(
+                    dirname,
+                    tag=basename,
+                    client_state=extra_ckpt,
+                    save_latest=False)
+        else:
+            if self.model.zero_optimization_partition_weights():
+                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
+                # `exclude_frozen_parameters`.
+                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+            else:
+                state_dict = self.model.module_state_dict(
+                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+            if is_main_process():
+                ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
+                save_checkpoint(ckpt, filename)
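
On this new `save_optimizer=False` path, rank 0 gathers a consolidated 16-bit state dict (under ZeRO-3) or the module state dict and writes an ordinary single-file checkpoint via mmengine's `save_checkpoint`, rather than a DeepSpeed checkpoint directory. A minimal, hypothetical consumer of such a file (the path and model class are placeholders, not names defined by this commit):

# Hypothetical loading sketch for a weights-only checkpoint produced by the
# `save_optimizer=False` branch above.
import torch

ckpt = torch.load('./work_dirs/demo/epoch_10.pth', map_location='cpu')
model = MyModel()  # must match the architecture used during training
model.load_state_dict(ckpt['state_dict'])
print(ckpt['meta'])  # seed, timestamp and mmengine version stored above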