mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Improve] Support _load_state_dict_post_hooks
in load_state_dict
. (#1103)
* [Improve] Support `_load_state_dict_post_hooks` in `load_state_dict`. * Update * Add unit test
This commit is contained in:
parent
6ba667c8cf
commit
49b27dd83f
@ -5,7 +5,7 @@ import os
|
||||
import os.path as osp
|
||||
import pkgutil
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, namedtuple
|
||||
from importlib import import_module
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict, Optional
|
||||
@ -33,6 +33,17 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
|
||||
|
||||
class _IncompatibleKeys(
        namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
    """Named pair ``(missing_keys, unexpected_keys)`` with a friendly repr.

    Behaves like the underlying named tuple in every respect except its
    string form: when both fields are empty (falsy) it renders as a short
    success message instead of the usual field listing.
    """

    def __repr__(self):
        # Anything left in either field means the default namedtuple
        # rendering is the informative one.
        if self.missing_keys or self.unexpected_keys:
            return super().__repr__()
        return '<All keys matched successfully>'

    # ``str()`` and ``repr()`` intentionally produce the same text.
    __str__ = __repr__
|
||||
|
||||
|
||||
def _get_mmengine_home():
|
||||
mmengine_home = os.path.expanduser(
|
||||
os.getenv(
|
||||
@ -61,35 +72,53 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
|
||||
message. If not specified, print function will be used.
|
||||
"""
|
||||
unexpected_keys = []
|
||||
all_missing_keys = []
|
||||
missing_keys = []
|
||||
err_msg = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
# use _load_from_state_dict to enable checkpoint version control
|
||||
def load(module, prefix=''):
|
||||
def load(module, local_state_dict, prefix=''):
|
||||
# recursively check parallel module in case that the model has a
|
||||
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
||||
if is_model_wrapper(module) or isinstance(module, BaseTTAModel):
|
||||
module = module.module
|
||||
local_metadata = {} if metadata is None else metadata.get(
|
||||
prefix[:-1], {})
|
||||
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
||||
all_missing_keys, unexpected_keys,
|
||||
module._load_from_state_dict(local_state_dict, prefix, local_metadata,
|
||||
True, missing_keys, unexpected_keys,
|
||||
err_msg)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
child_prefix = prefix + name + '.'
|
||||
child_state_dict = {
|
||||
k: v
|
||||
for k, v in local_state_dict.items()
|
||||
if k.startswith(child_prefix)
|
||||
}
|
||||
load(child, child_state_dict, child_prefix)
|
||||
|
||||
load(module)
|
||||
# Note that the hook can modify missing_keys and unexpected_keys.
|
||||
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
if hasattr(module, '_load_state_dict_post_hooks'):
|
||||
for hook in module._load_state_dict_post_hooks.values():
|
||||
out = hook(module, incompatible_keys)
|
||||
assert out is None, (
|
||||
'Hooks registered with '
|
||||
'``register_load_state_dict_post_hook`` are not expected '
|
||||
'to return new values, if incompatible_keys need to be '
|
||||
'modified, it should be done inplace.')
|
||||
|
||||
load(module, state_dict)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
# ignore "num_batches_tracked" of BN layers
|
||||
missing_keys = [
|
||||
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
||||
key for key in missing_keys if 'num_batches_tracked' not in key
|
||||
]
|
||||
|
||||
if unexpected_keys:
|
||||
|
@ -18,7 +18,7 @@ from mmengine.runner.checkpoint import (CheckpointLoader,
|
||||
_load_checkpoint_with_prefix,
|
||||
get_state_dict, load_checkpoint,
|
||||
load_from_local, load_from_pavi,
|
||||
save_checkpoint)
|
||||
load_state_dict, save_checkpoint)
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
@ -406,3 +406,30 @@ def test_load_from_local():
|
||||
assert_tensor_equal(checkpoint['block.conv.weight'],
|
||||
model.block.conv.weight)
|
||||
os.remove(checkpoint_path)
|
||||
|
||||
|
||||
def test_load_state_dict_post_hooks():
    """Check that ``load_state_dict`` runs ``_load_state_dict_post_hooks``.

    A state dict lacking ``norm.running_var`` is loaded twice into the same
    ``Block``: first without any hook, where ``print_log`` must be called
    exactly once, and then with a hook that removes that key from
    ``missing_keys`` in place, where ``print_log`` must not be called.
    """
    block = Block()

    weights = {
        'conv.weight': torch.empty((3, 3, 1, 1), dtype=torch.float32),
        'conv.bias': torch.empty((3, ), dtype=torch.float32),
        'norm.weight': torch.empty([3], dtype=torch.float32),
        'norm.bias': torch.empty([3], dtype=torch.float32),
        'norm.running_mean': torch.empty([3], dtype=torch.float32),
        'norm.running_var': torch.empty([3], dtype=torch.float32),
    }
    # Drop one entry so that the load has a missing key.
    del weights['norm.running_var']

    # No hook registered: exactly one log call is expected
    # (presumably the missing-key report -- confirm against load_state_dict).
    with patch('mmengine.runner.checkpoint.print_log') as logged:
        load_state_dict(block, weights, strict=False)
        logged.assert_called_once()

    def drop_running_var(_, incompatible_keys):
        # Hooks must edit the key lists in place and return nothing.
        incompatible_keys.missing_keys.remove('norm.running_var')

    block._load_state_dict_post_hooks = {0: drop_running_var}

    # With the hook clearing the missing key, nothing should be logged.
    with patch('mmengine.runner.checkpoint.print_log') as logged:
        load_state_dict(block, weights, strict=False)
        logged.assert_not_called()
|
||||
|
Loading…
x
Reference in New Issue
Block a user