diff --git a/docs/en/api/utils.rst b/docs/en/api/utils.rst index 681e15d2..21c8f38e 100644 --- a/docs/en/api/utils.rst +++ b/docs/en/api/utils.rst @@ -116,3 +116,4 @@ Miscellaneous requires_executable requires_package check_time + apply_to diff --git a/docs/zh_cn/api/utils.rst b/docs/zh_cn/api/utils.rst index 681e15d2..21c8f38e 100644 --- a/docs/zh_cn/api/utils.rst +++ b/docs/zh_cn/api/utils.rst @@ -116,3 +116,4 @@ Miscellaneous requires_executable requires_package check_time + apply_to diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 6b4f4d6e..cf6cf0c4 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -18,7 +18,8 @@ from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file from mmengine.logging import print_log from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist +from mmengine.utils import (apply_to, deprecated_function, digit_version, + mkdir_or_exist) from mmengine.utils.dl_utils import load_url # `MMENGINE_HOME` is the highest priority directory to save checkpoints @@ -622,12 +623,11 @@ def weights_to_cpu(state_dict): Returns: OrderedDict: Model weights on GPU. """ - state_dict_cpu = OrderedDict() - for key, val in state_dict.items(): - state_dict_cpu[key] = val.cpu() + state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), + lambda x: x.cpu()) # Keep metadata in state_dict - state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict()) - return state_dict_cpu + state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict()) + return state_dict @deprecated_function( diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index c4e787f4..29d704ea 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -35,14 +35,14 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS, VISUALIZERS, DefaultScope) -from mmengine.utils import digit_version, get_git_hash, is_seq_of +from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) from mmengine.visualization import Visualizer from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, - find_latest_checkpoint, get_state_dict, - save_checkpoint, weights_to_cpu) + find_latest_checkpoint, save_checkpoint, + weights_to_cpu) from .log_processor import LogProcessor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority @@ -2164,14 +2164,20 @@ class Runner: model = self.model checkpoint = { - 'meta': meta, - 'state_dict': weights_to_cpu(get_state_dict(model)), - 'message_hub': self.message_hub.state_dict() + 'meta': + meta, + 'state_dict': + weights_to_cpu(model.state_dict()), + 'message_hub': + apply_to(self.message_hub.state_dict(), + lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()), } # save optimizer state dict to checkpoint if save_optimizer: if isinstance(self.optim_wrapper, OptimWrapper): - checkpoint['optimizer'] = self.optim_wrapper.state_dict() + checkpoint['optimizer'] = apply_to( + self.optim_wrapper.state_dict(), + lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()) else: raise TypeError( 'self.optim_wrapper should be an `OptimWrapper` ' diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 9ad47dda..2800d935 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .manager import ManagerMeta, ManagerMixin -from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - deprecated_function, has_method, +from .misc import (apply_to, check_prerequisites, concat_list, + deprecated_api_warning, deprecated_function, has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, requires_package, @@ -27,5 +27,6 @@ __all__ = [ 'is_abs', 'is_method_overridden', 'has_method', 'digit_version', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress', - 'track_parallel_progress', 'track_progress', 'deprecated_function' + 'track_parallel_progress', 'track_progress', 'deprecated_function', + 'apply_to' ] diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 39703f38..aaea15c4 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -217,6 +217,47 @@ def concat_list(in_list): return list(itertools.chain(*in_list)) +def apply_to(data: Any, expr: Callable, apply_func: Callable): + """Apply function to each element in dict, list or tuple that matches with + the expression. + + For examples, if you want to convert each element in a list of dict from + `np.ndarray` to `Tensor`. You can use the following code: + + Examples: + >>> from mmengine.utils import apply_to + >>> import numpy as np + >>> import torch + >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]} + >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x)) + >>> print(result) # {'array': [tensor(1)]} + + Args: + data (Any): Data to be applied. + expr (Callable): Expression to tell which data should be applied with + the function. It should return a boolean. + apply_func (Callable): Function applied to data. + + Returns: + Any: The data after applying. + """ # noqa: E501 + if isinstance(data, dict): + # Keep the original dict type + res = type(data)() + for key, value in data.items(): + res[key] = apply_to(value, expr, apply_func) + return res + elif isinstance(data, tuple) and hasattr(data, '_fields'): + # namedtuple + return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # noqa: E501 # yapf:disable + elif isinstance(data, (tuple, list)): + return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # noqa: E501 # yapf:disable + elif expr(data): + return apply_func(data) + else: + return data + + def check_prerequisites( prerequisites, checker, diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 95d7a006..3798759e 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -1,9 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +from collections import namedtuple + +import numpy as np import pytest +import torch from mmengine import MMLogger # yapf: disable -from mmengine.utils.misc import (concat_list, deprecated_api_warning, +from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning, deprecated_function, has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_tuple_of, @@ -283,3 +287,43 @@ def test_deprecated_function(): Short summary.""" # noqa: E122 assert expected_docstring.strip(' ') == deprecated_demo1.__doc__ + + +def test_apply_to(): + # Test only apply `+1` to int object. + data = dict(a=1, b=2.0) + result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1) + assert result == dict(a=2, b=2.0) + + # Test with nested data + data = dict(a=[dict(c=1)], b=2.0) + result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1) + assert result == dict(a=[dict(c=2)], b=2.0) + + # Tensor to numpy + data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2)) + result = apply_to(data, lambda x: isinstance(x, torch.Tensor), + lambda x: x.numpy()) + assert isinstance(result['b'], np.ndarray) + assert isinstance(result['a'][0]['c'], np.ndarray) + + # Tuple and convert string + data = (1, dict(a=[dict(b=2.0)]), 'test') + result = apply_to( + data, lambda x: isinstance(x, int) or x == 'test', + lambda x: torch.Tensor(x) if isinstance(x, int) else 'train') + assert isinstance(result, tuple) + assert isinstance(result[0], torch.Tensor) + assert isinstance(result[1]['a'][0]['b'], float) + assert result[2] == 'train' + + # Named Tuple + dataclass = namedtuple('Data', ['a', 'b']) + data = dataclass('test', dict(a=[dict(c=1)], b=2.0)) + result = apply_to( + data, lambda x: isinstance(x, int) or x == 'test', + lambda x: torch.Tensor(x) if isinstance(x, int) else 'train') + assert isinstance(result, dataclass) + assert result[0] == 'train' + assert isinstance(result.b['a'][0]['c'], torch.Tensor) + assert isinstance(result.b['b'], float)