Mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Save optimizer.state_dict() in cpu by default (#966)
commit 6ba667c8cf (parent 9868131c98)
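The commit makes Runner.save_checkpoint move every tensor in the checkpoint to CPU before serialization: model weights (via weights_to_cpu), the message hub state, and, when save_optimizer=True, the optimizer state dict. Previously the optimizer and message hub states were saved on whatever device they lived on, so a checkpoint written on GPU held CUDA tensors. A minimal sketch of that failure mode (requires a CUDA device; the filename is illustrative):

import torch
from torch import nn

model = nn.Linear(2, 2).cuda()
optim = torch.optim.Adam(model.parameters())
model(torch.ones(2, device='cuda')).sum().backward()
optim.step()  # optimizer state (exp_avg, exp_avg_sq, ...) now lives on the GPU

# Saved as-is, the state dict keeps its CUDA tensors ...
torch.save({'optimizer': optim.state_dict()}, 'ckpt.pth')
# ... so loading on a CPU-only machine fails unless map_location is given:
# torch.load('ckpt.pth')  # RuntimeError: Attempting to deserialize object on a CUDA device

The new utility mmengine.utils.apply_to performs this device move generically over nested dicts, lists, tuples, and namedtuples.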
Docs: add apply_to to the Miscellaneous section of the utils API reference (the commit applies the identical hunk to two docs files):

@@ -116,3 +116,4 @@ Miscellaneous
 requires_executable
 requires_package
 check_time
+apply_to
mmengine/runner/checkpoint.py (imports):

@@ -18,7 +18,8 @@ from mmengine.fileio import FileClient, get_file_backend
 from mmengine.fileio import load as load_file
 from mmengine.logging import print_log
 from mmengine.model import BaseTTAModel, is_model_wrapper
-from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
+from mmengine.utils import (apply_to, deprecated_function, digit_version,
+                            mkdir_or_exist)
 from mmengine.utils.dl_utils import load_url
 
 # `MMENGINE_HOME` is the highest priority directory to save checkpoints
mmengine/runner/checkpoint.py (weights_to_cpu now recurses with apply_to instead of iterating over top-level keys; the Returns docstring previously said "GPU" by mistake):

@@ -622,12 +623,11 @@ def weights_to_cpu(state_dict):
     Returns:
         OrderedDict: Model weights on CPU.
     """
-    state_dict_cpu = OrderedDict()
-    for key, val in state_dict.items():
-        state_dict_cpu[key] = val.cpu()
+    state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'),
+                          lambda x: x.cpu())
     # Keep metadata in state_dict
-    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
-    return state_dict_cpu
+    state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    return state_dict
 
 
 @deprecated_function(
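Because apply_to only transforms values for which the predicate holds, the rewritten weights_to_cpu also tolerates non-tensor entries and nested containers that the old key-by-key loop would have raised on. A hedged sketch (assumes a CUDA device; the non-tensor entry is contrived, purely to show the pass-through):

from collections import OrderedDict
import torch
from mmengine.runner.checkpoint import weights_to_cpu

state = OrderedDict(weight=torch.ones(2).cuda(), step=3)
cpu_state = weights_to_cpu(state)
assert cpu_state['weight'].device.type == 'cpu'
assert cpu_state['step'] == 3  # no .cpu() method, passed through unchanged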
mmengine/runner/runner.py (imports; get_state_dict is no longer needed here):

@@ -35,14 +35,14 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
                                HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
                                MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
                                RUNNERS, VISUALIZERS, DefaultScope)
-from mmengine.utils import digit_version, get_git_hash, is_seq_of
+from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
 from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
                                      set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
-                         find_latest_checkpoint, get_state_dict,
-                         save_checkpoint, weights_to_cpu)
+                         find_latest_checkpoint, save_checkpoint,
+                         weights_to_cpu)
 from .log_processor import LogProcessor
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
mmengine/runner/runner.py (save_checkpoint now builds the checkpoint from CPU tensors: model weights via weights_to_cpu(model.state_dict()), message hub and optimizer states via apply_to):

@@ -2164,14 +2164,20 @@ class Runner:
             model = self.model
 
         checkpoint = {
-            'meta': meta,
-            'state_dict': weights_to_cpu(get_state_dict(model)),
-            'message_hub': self.message_hub.state_dict()
+            'meta':
+            meta,
+            'state_dict':
+            weights_to_cpu(model.state_dict()),
+            'message_hub':
+            apply_to(self.message_hub.state_dict(),
+                     lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
         }
         # save optimizer state dict to checkpoint
         if save_optimizer:
             if isinstance(self.optim_wrapper, OptimWrapper):
-                checkpoint['optimizer'] = self.optim_wrapper.state_dict()
+                checkpoint['optimizer'] = apply_to(
+                    self.optim_wrapper.state_dict(),
+                    lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
             else:
                 raise TypeError(
                     'self.optim_wrapper should be an `OptimWrapper` '
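The hasattr(x, 'cpu') predicate keeps the transform duck-typed: tensors (and anything else exposing a .cpu() method) are moved, while the ints, floats, and strings that populate param_groups pass through untouched. A hedged sketch of the optimizer branch in isolation (assumes a CUDA device):

import torch
from torch import nn
from mmengine.utils import apply_to

model = nn.Linear(2, 2).cuda()
optim = torch.optim.Adam(model.parameters())
model(torch.ones(2, device='cuda')).sum().backward()
optim.step()

# The same transform the runner now applies before saving:
cpu_sd = apply_to(optim.state_dict(), lambda x: hasattr(x, 'cpu'),
                  lambda x: x.cpu())
assert cpu_sd['state'][0]['exp_avg'].device.type == 'cpu'
assert cpu_sd['param_groups'][0]['lr'] == optim.defaults['lr']  # untouched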
mmengine/utils/__init__.py (export apply_to):

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .manager import ManagerMeta, ManagerMixin
-from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
-                   deprecated_function, has_method,
+from .misc import (apply_to, check_prerequisites, concat_list,
+                   deprecated_api_warning, deprecated_function, has_method,
                    import_modules_from_strings, is_list_of,
                    is_method_overridden, is_seq_of, is_str, is_tuple_of,
                    iter_cast, list_cast, requires_executable, requires_package,
@@ -27,5 +27,6 @@ __all__ = [
     'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
     'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
     'TimerError', 'ProgressBar', 'track_iter_progress',
-    'track_parallel_progress', 'track_progress', 'deprecated_function'
+    'track_parallel_progress', 'track_progress', 'deprecated_function',
+    'apply_to'
 ]
mmengine/utils/misc.py (the new apply_to helper):

@@ -217,6 +217,47 @@ def concat_list(in_list):
     return list(itertools.chain(*in_list))
 
 
+def apply_to(data: Any, expr: Callable, apply_func: Callable):
+    """Apply a function to each element in a dict, list or tuple that
+    matches the expression.
+
+    For example, if you want to convert each `np.ndarray` in a list of
+    dicts to a `Tensor`, you can use the following code:
+
+    Examples:
+        >>> from mmengine.utils import apply_to
+        >>> import numpy as np
+        >>> import torch
+        >>> data = dict(array=[np.array(1)])  # {'array': [array(1)]}
+        >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x))
+        >>> print(result)  # {'array': [tensor(1)]}
+
+    Args:
+        data (Any): Data to be processed.
+        expr (Callable): Expression that tells which data the function
+            should be applied to. It should return a boolean.
+        apply_func (Callable): Function applied to data.
+
+    Returns:
+        Any: The data after applying the function.
+    """  # noqa: E501
+    if isinstance(data, dict):
+        # Keep the original dict type
+        res = type(data)()
+        for key, value in data.items():
+            res[key] = apply_to(value, expr, apply_func)
+        return res
+    elif isinstance(data, tuple) and hasattr(data, '_fields'):
+        # namedtuple
+        return type(data)(*(apply_to(sample, expr, apply_func) for sample in data))  # type: ignore  # noqa: E501  # yapf:disable
+    elif isinstance(data, (tuple, list)):
+        return type(data)(apply_to(sample, expr, apply_func) for sample in data)  # type: ignore  # noqa: E501  # yapf:disable
+    elif expr(data):
+        return apply_func(data)
+    else:
+        return data
+
+
 def check_prerequisites(
     prerequisites,
     checker,
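A note on the recursion: each container is rebuilt with type(data)(...), so dict subclasses such as OrderedDict survive the round trip, namedtuples are reconstructed field by field, and the input containers are never mutated. Containers outside dict/list/tuple (a set, for instance) are treated as leaves and handed to expr. A small illustrative sketch:

from collections import OrderedDict
from mmengine.utils import apply_to

src = OrderedDict(a=1, b=OrderedDict(c=2))
out = apply_to(src, lambda x: isinstance(x, int), lambda x: x * 10)
assert type(out) is OrderedDict and type(out['b']) is OrderedDict
assert out == OrderedDict(a=10, b=OrderedDict(c=20))
assert src == OrderedDict(a=1, b=OrderedDict(c=2))  # input left intact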
tests/test_utils/test_misc.py (imports for the new test):

@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from collections import namedtuple
+
+import numpy as np
 import pytest
+import torch
 
 from mmengine import MMLogger
 # yapf: disable
-from mmengine.utils.misc import (concat_list, deprecated_api_warning,
+from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
                                  deprecated_function, has_method,
                                  import_modules_from_strings, is_list_of,
                                  is_method_overridden, is_seq_of, is_tuple_of,
@@ -283,3 +287,43 @@ def test_deprecated_function():
 
     Short summary."""  # noqa: E122
     assert expected_docstring.strip(' ') == deprecated_demo1.__doc__
+
+
+def test_apply_to():
+    # Test only apply `+1` to int object.
+    data = dict(a=1, b=2.0)
+    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
+    assert result == dict(a=2, b=2.0)
+
+    # Test with nested data
+    data = dict(a=[dict(c=1)], b=2.0)
+    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
+    assert result == dict(a=[dict(c=2)], b=2.0)
+
+    # Tensor to numpy
+    data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2))
+    result = apply_to(data, lambda x: isinstance(x, torch.Tensor),
+                      lambda x: x.numpy())
+    assert isinstance(result['b'], np.ndarray)
+    assert isinstance(result['a'][0]['c'], np.ndarray)
+
+    # Tuple and convert string
+    data = (1, dict(a=[dict(b=2.0)]), 'test')
+    result = apply_to(
+        data, lambda x: isinstance(x, int) or x == 'test',
+        lambda x: torch.Tensor(x) if isinstance(x, int) else 'train')
+    assert isinstance(result, tuple)
+    assert isinstance(result[0], torch.Tensor)
+    assert isinstance(result[1]['a'][0]['b'], float)
+    assert result[2] == 'train'
+
+    # Named Tuple
+    dataclass = namedtuple('Data', ['a', 'b'])
+    data = dataclass('test', dict(a=[dict(c=1)], b=2.0))
+    result = apply_to(
+        data, lambda x: isinstance(x, int) or x == 'test',
+        lambda x: torch.Tensor(x) if isinstance(x, int) else 'train')
+    assert isinstance(result, dataclass)
+    assert result[0] == 'train'
+    assert isinstance(result.b['a'][0]['c'], torch.Tensor)
+    assert isinstance(result.b['b'], float)
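Assuming the repository's standard layout, the new test runs in isolation with `pytest tests/test_utils/test_misc.py -k apply_to`. Note that the namedtuple case checks both that the result is still an instance of the original namedtuple class and that untouched leaves (the float 2.0) keep their types.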