[Enhancement] Deprecate _save_to_state_dict implemented in mmengine (#610)

* [Refine] Make get_state_dict directly call nn.Module._save_to_state_dict

* deprecate _save_to_state_dict

* deprecate _save_to_state_dict in 0.5.0

* deprecate in 0.3.0

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/650/head
Mashiro 2022-10-28 17:14:08 +08:00 committed by GitHub
parent d1dd240796
commit 4aad15df90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 4 deletions

View File

@ -19,7 +19,7 @@ from mmengine.fileio import FileClient, get_file_backend
from mmengine.fileio import load as load_file from mmengine.fileio import load as load_file
from mmengine.logging import print_log from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import mkdir_or_exist from mmengine.utils import deprecated_function, mkdir_or_exist
from mmengine.utils.dl_utils import load_url from mmengine.utils.dl_utils import load_url
# `MMENGINE_HOME` is the highest priority directory to save checkpoints # `MMENGINE_HOME` is the highest priority directory to save checkpoints
@ -574,6 +574,11 @@ def weights_to_cpu(state_dict):
return state_dict_cpu return state_dict_cpu
@deprecated_function(
since='0.3.0',
removed_in='0.5.0',
instructions='`_save_to_state_dict` will be deprecated in the future, '
'please use `nn.Module._save_to_state_dict` directly.')
def _save_to_state_dict(module, destination, prefix, keep_vars): def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary. """Saves module state to `destination` dictionary.
@ -626,7 +631,7 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
destination._metadata = OrderedDict() destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict( destination._metadata[prefix[:-1]] = local_metadata = dict(
version=module._version) version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars) module._save_to_state_dict(destination, prefix, keep_vars)
for name, child in module._modules.items(): for name, child in module._modules.items():
if child is not None: if child is not None:
get_state_dict( get_state_dict(

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .manager import ManagerMeta, ManagerMixin from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning, from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
has_method, import_modules_from_strings, is_list_of, deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of, is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package, iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
@ -26,5 +27,5 @@ __all__ = [
'is_abs', 'is_method_overridden', 'has_method', 'digit_version', 'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress', 'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress' 'track_parallel_progress', 'track_progress', 'deprecated_function'
] ]