mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Save optimizer.state_dict() in cpu by default (#966)
This commit is contained in:
parent
9868131c98
commit
6ba667c8cf
@ -116,3 +116,4 @@ Miscellaneous
|
||||
requires_executable
|
||||
requires_package
|
||||
check_time
|
||||
apply_to
|
||||
|
@ -116,3 +116,4 @@ Miscellaneous
|
||||
requires_executable
|
||||
requires_package
|
||||
check_time
|
||||
apply_to
|
||||
|
@ -18,7 +18,8 @@ from mmengine.fileio import FileClient, get_file_backend
|
||||
from mmengine.fileio import load as load_file
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseTTAModel, is_model_wrapper
|
||||
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
|
||||
from mmengine.utils import (apply_to, deprecated_function, digit_version,
|
||||
mkdir_or_exist)
|
||||
from mmengine.utils.dl_utils import load_url
|
||||
|
||||
# `MMENGINE_HOME` is the highest priority directory to save checkpoints
|
||||
@ -622,12 +623,11 @@ def weights_to_cpu(state_dict):
|
||||
Returns:
|
||||
OrderedDict: Model weights on CPU.
|
||||
"""
|
||||
state_dict_cpu = OrderedDict()
|
||||
for key, val in state_dict.items():
|
||||
state_dict_cpu[key] = val.cpu()
|
||||
state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'),
|
||||
lambda x: x.cpu())
|
||||
# Keep metadata in state_dict
|
||||
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
|
||||
return state_dict_cpu
|
||||
state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict())
|
||||
return state_dict
|
||||
|
||||
|
||||
@deprecated_function(
|
||||
|
@ -35,14 +35,14 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
|
||||
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
|
||||
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
|
||||
RUNNERS, VISUALIZERS, DefaultScope)
|
||||
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
||||
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
|
||||
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
||||
set_multi_processing)
|
||||
from mmengine.visualization import Visualizer
|
||||
from .base_loop import BaseLoop
|
||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||
find_latest_checkpoint, get_state_dict,
|
||||
save_checkpoint, weights_to_cpu)
|
||||
find_latest_checkpoint, save_checkpoint,
|
||||
weights_to_cpu)
|
||||
from .log_processor import LogProcessor
|
||||
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from .priority import Priority, get_priority
|
||||
@ -2164,14 +2164,20 @@ class Runner:
|
||||
model = self.model
|
||||
|
||||
checkpoint = {
|
||||
'meta': meta,
|
||||
'state_dict': weights_to_cpu(get_state_dict(model)),
|
||||
'message_hub': self.message_hub.state_dict()
|
||||
'meta':
|
||||
meta,
|
||||
'state_dict':
|
||||
weights_to_cpu(model.state_dict()),
|
||||
'message_hub':
|
||||
apply_to(self.message_hub.state_dict(),
|
||||
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
|
||||
}
|
||||
# save optimizer state dict to checkpoint
|
||||
if save_optimizer:
|
||||
if isinstance(self.optim_wrapper, OptimWrapper):
|
||||
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
|
||||
checkpoint['optimizer'] = apply_to(
|
||||
self.optim_wrapper.state_dict(),
|
||||
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
|
||||
else:
|
||||
raise TypeError(
|
||||
'self.optim_wrapper should be an `OptimWrapper` '
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .manager import ManagerMeta, ManagerMixin
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
deprecated_function, has_method,
|
||||
from .misc import (apply_to, check_prerequisites, concat_list,
|
||||
deprecated_api_warning, deprecated_function, has_method,
|
||||
import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
||||
iter_cast, list_cast, requires_executable, requires_package,
|
||||
@ -27,5 +27,6 @@ __all__ = [
|
||||
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
|
||||
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
|
||||
'TimerError', 'ProgressBar', 'track_iter_progress',
|
||||
'track_parallel_progress', 'track_progress', 'deprecated_function'
|
||||
'track_parallel_progress', 'track_progress', 'deprecated_function',
|
||||
'apply_to'
|
||||
]
|
||||
|
@ -217,6 +217,47 @@ def concat_list(in_list):
|
||||
return list(itertools.chain(*in_list))
|
||||
|
||||
|
||||
def apply_to(data: Any, expr: Callable, apply_func: Callable) -> Any:
    """Apply a function to each element in a dict, list or tuple that matches
    the expression.

    Dicts, lists, tuples and namedtuples are traversed recursively and
    rebuilt with the same container type. ``expr`` is only evaluated on
    values that are not one of these containers, and dict keys are never
    passed to ``expr`` or ``apply_func``. The input is not modified; new
    containers are returned.

    For example, to convert each element in a list of dict from
    ``np.ndarray`` to ``Tensor``, you can use the following code:

    Examples:
        >>> from mmengine.utils import apply_to
        >>> import numpy as np
        >>> import torch
        >>> data = dict(array=[np.array(1)])  # {'array': [array(1)]}
        >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray),
        ...                   lambda x: torch.from_numpy(x))
        >>> print(result)  # {'array': [tensor(1)]}

    Args:
        data (Any): Data to be applied.
        expr (Callable): Expression to tell which data should be applied with
            the function. It should return a boolean.
        apply_func (Callable): Function applied to data.

    Returns:
        Any: The data after applying.
    """
    if isinstance(data, dict):
        # Local import keeps this block self-contained.
        from collections import defaultdict

        # Keep the original dict type. A ``defaultdict`` built without
        # arguments would silently lose its ``default_factory``, so forward
        # it explicitly.
        if isinstance(data, defaultdict):
            res = type(data)(data.default_factory)
        else:
            res = type(data)()
        for key, value in data.items():
            res[key] = apply_to(value, expr, apply_func)
        return res
    if isinstance(data, tuple) and hasattr(data, '_fields'):
        # namedtuple: the constructor takes one positional argument per field
        return type(data)(  # type: ignore
            *(apply_to(sample, expr, apply_func) for sample in data))
    if isinstance(data, (tuple, list)):
        return type(data)(  # type: ignore
            apply_to(sample, expr, apply_func) for sample in data)
    if expr(data):
        return apply_func(data)
    return data
|
||||
|
||||
|
||||
def check_prerequisites(
|
||||
prerequisites,
|
||||
checker,
|
||||
|
@ -1,9 +1,13 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine import MMLogger
|
||||
# yapf: disable
|
||||
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
|
||||
from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
|
||||
deprecated_function, has_method,
|
||||
import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_tuple_of,
|
||||
@ -283,3 +287,43 @@ def test_deprecated_function():
|
||||
|
||||
Short summary.""" # noqa: E122
|
||||
assert expected_docstring.strip(' ') == deprecated_demo1.__doc__
|
||||
|
||||
|
||||
def test_apply_to():
    """Check that ``apply_to`` converts only the matching leaves of nested
    data and preserves the container types."""
    # Only the int leaf gets `+1`; the float is left untouched.
    data = dict(a=1, b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=2, b=2.0)

    # Test with nested data
    data = dict(a=[dict(c=1)], b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=[dict(c=2)], b=2.0)

    # Tensor to numpy
    data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2))
    result = apply_to(data, lambda x: isinstance(x, torch.Tensor),
                      lambda x: x.numpy())
    assert isinstance(result['b'], np.ndarray)
    assert isinstance(result['a'][0]['c'], np.ndarray)

    # Tuple and convert string.
    # `torch.tensor(x)` wraps the value; `torch.Tensor(x)` with an int would
    # instead allocate an uninitialised tensor with `x` elements.
    data = (1, dict(a=[dict(b=2.0)]), 'test')
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, tuple)
    assert isinstance(result[0], torch.Tensor)
    assert result[0].item() == 1
    assert isinstance(result[1]['a'][0]['b'], float)
    assert result[2] == 'train'

    # Named Tuple
    Data = namedtuple('Data', ['a', 'b'])
    data = Data('test', dict(a=[dict(c=1)], b=2.0))
    result = apply_to(
        data, lambda x: isinstance(x, int) or x == 'test',
        lambda x: torch.tensor(x) if isinstance(x, int) else 'train')
    assert isinstance(result, Data)
    assert result[0] == 'train'
    assert isinstance(result.b['a'][0]['c'], torch.Tensor)
    assert result.b['a'][0]['c'].item() == 1
    assert isinstance(result.b['b'], float)
|
||||
|
Loading…
x
Reference in New Issue
Block a user