diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 99ab023c..97e4b197 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -19,7 +19,7 @@ from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file from mmengine.logging import print_log from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import mkdir_or_exist +from mmengine.utils import deprecated_function, mkdir_or_exist from mmengine.utils.dl_utils import load_url # `MMENGINE_HOME` is the highest priority directory to save checkpoints @@ -574,6 +574,11 @@ def weights_to_cpu(state_dict): return state_dict_cpu +@deprecated_function( + since='0.3.0', + removed_in='0.5.0', + instructions='`_save_to_state_dict` will be deprecated in the future, ' + 'please use `nn.Module._save_to_state_dict` directly.') def _save_to_state_dict(module, destination, prefix, keep_vars): """Saves module state to `destination` dictionary. @@ -626,7 +631,7 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict( version=module._version) - _save_to_state_dict(module, destination, prefix, keep_vars) + module._save_to_state_dict(destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: get_state_dict( diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 71f6f4f8..9ad47dda 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .manager import ManagerMeta, ManagerMixin from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - has_method, import_modules_from_strings, is_list_of, + deprecated_function, has_method, + import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, @@ -26,5 +27,5 @@ __all__ = [ 'is_abs', 'is_method_overridden', 'has_method', 'digit_version', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress', - 'track_parallel_progress', 'track_progress' + 'track_parallel_progress', 'track_progress', 'deprecated_function' ]