[Fix] Save optimizer.state_dict() in cpu by default (#966)

This commit is contained in:
Mashiro 2023-04-26 16:47:47 +08:00 committed by GitHub
parent 9868131c98
commit 6ba667c8cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 17 deletions

View File

@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to

View File

@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to

View File

@ -18,7 +18,8 @@ from mmengine.fileio import FileClient, get_file_backend
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
from mmengine.utils import (apply_to, deprecated_function, digit_version,
mkdir_or_exist)
from mmengine.utils.dl_utils import load_url
# `MMENGINE_HOME` is the highest priority directory to save checkpoints
@ -622,12 +623,11 @@ def weights_to_cpu(state_dict):
Returns:
OrderedDict: Model weights on GPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'),
lambda x: x.cpu())
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict_cpu
state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict
@deprecated_function(

View File

@ -35,14 +35,14 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
find_latest_checkpoint, save_checkpoint,
weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
@ -2164,14 +2164,20 @@ class Runner:
model = self.model
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)),
'message_hub': self.message_hub.state_dict()
'meta':
meta,
'state_dict':
weights_to_cpu(model.state_dict()),
'message_hub':
apply_to(self.message_hub.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
}
# save optimizer state dict to checkpoint
if save_optimizer:
if isinstance(self.optim_wrapper, OptimWrapper):
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
checkpoint['optimizer'] = apply_to(
self.optim_wrapper.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
else:
raise TypeError(
'self.optim_wrapper should be an `OptimWrapper` '

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
deprecated_function, has_method,
from .misc import (apply_to, check_prerequisites, concat_list,
deprecated_api_warning, deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
@ -27,5 +27,6 @@ __all__ = [
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'deprecated_function'
'track_parallel_progress', 'track_progress', 'deprecated_function',
'apply_to'
]

View File

@ -217,6 +217,47 @@ def concat_list(in_list):
return list(itertools.chain(*in_list))
def apply_to(data: Any, expr: Callable, apply_func: Callable) -> Any:
    """Apply a function to each element in dict, list or tuple that matches
    the expression.

    Containers (``dict``, ``list``, ``tuple`` and namedtuples, including
    their subclasses) are traversed recursively and rebuilt with the same
    type. ``expr`` is only evaluated on the non-container leaves; the input
    containers themselves are never modified.

    For example, if you want to convert each element in a list of dict from
    `np.ndarray` to `Tensor`, you can use the following code:

    Examples:
        >>> from mmengine.utils import apply_to
        >>> import numpy as np
        >>> import torch
        >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]}
        >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x))
        >>> print(result) # {'array': [tensor(1)]}

    Args:
        data (Any): Data to be applied.
        expr (Callable): Expression to tell which data should be applied with
            the function. It should return a boolean.
        apply_func (Callable): Function applied to data.

    Returns:
        Any: The data after applying. Rebuilt dicts keep the type of the
        original, its ``default_factory`` (for ``defaultdict``) and its
        instance attributes (shared by reference, not copied).
    """  # noqa: E501
    if isinstance(data, dict):
        # Imported here so the block does not rely on a module-level import.
        from collections import defaultdict

        # Keep the original dict type
        res = type(data)()
        if isinstance(data, defaultdict):
            # ``type(data)()`` creates a defaultdict without a factory.
            res.default_factory = data.default_factory
        for key, value in data.items():
            res[key] = apply_to(value, expr, apply_func)
        # Dict subclasses may carry instance attributes (for example the
        # ``_metadata`` attached to a state dict). A freshly constructed
        # instance does not have them, so carry them over after the items
        # have been filled in.
        extra_attrs = getattr(data, '__dict__', None)
        if extra_attrs:
            vars(res).update(extra_attrs)
        return res
    elif isinstance(data, tuple) and hasattr(data, '_fields'):
        # namedtuple: the constructor takes one positional argument per field
        return type(data)(*(apply_to(sample, expr, apply_func) for sample in data))  # type: ignore  # noqa: E501  # yapf:disable
    elif isinstance(data, (tuple, list)):
        return type(data)(apply_to(sample, expr, apply_func) for sample in data)  # type: ignore  # noqa: E501  # yapf:disable
    elif expr(data):
        return apply_func(data)
    else:
        return data
def check_prerequisites(
prerequisites,
checker,

View File

@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import namedtuple
import numpy as np
import pytest
import torch
from mmengine import MMLogger
# yapf: disable
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_tuple_of,
@ -283,3 +287,43 @@ def test_deprecated_function():
Short summary.""" # noqa: E122
assert expected_docstring.strip(' ') == deprecated_demo1.__doc__
def test_apply_to():
    """Check ``apply_to`` on flat, nested, tuple and namedtuple inputs."""
    # Test only apply `+1` to int object.
    data = dict(a=1, b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=2, b=2.0)
    # The input must be left untouched.
    assert data == dict(a=1, b=2.0)

    # Test with nested data
    data = dict(a=[dict(c=1)], b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=[dict(c=2)], b=2.0)
    assert data == dict(a=[dict(c=1)], b=2.0)

    # Tensor to numpy
    data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2))
    result = apply_to(data, lambda x: isinstance(x, torch.Tensor),
                      lambda x: x.numpy())
    assert isinstance(result['b'], np.ndarray)
    assert isinstance(result['a'][0]['c'], np.ndarray)
    assert result['b'].item() == 2
    assert result['a'][0]['c'].item() == 1

    # Tuple and convert string.
    # `torch.tensor(x)` wraps the value; `torch.Tensor(x)` with an int would
    # instead allocate an uninitialized tensor of length `x`.
    data = (1, dict(a=[dict(b=2.0)]), 'test')
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, tuple)
    assert isinstance(result[0], torch.Tensor)
    assert result[0].item() == 1
    assert isinstance(result[1]['a'][0]['b'], float)
    assert result[2] == 'train'

    # Named Tuple
    Data = namedtuple('Data', ['a', 'b'])
    data = Data('test', dict(a=[dict(c=1)], b=2.0))
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, Data)
    assert result[0] == 'train'
    assert isinstance(result.b['a'][0]['c'], torch.Tensor)
    assert result.b['a'][0]['c'].item() == 1
    assert isinstance(result.b['b'], float)