[Improve] Support _load_state_dict_post_hooks in load_state_dict. (#1103)

* [Improve] Support `_load_state_dict_post_hooks` in `load_state_dict`.

* Update

* Add unit test
This commit is contained in:
Ma Zerun 2023-04-26 16:48:57 +08:00 committed by GitHub
parent 6ba667c8cf
commit 49b27dd83f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 9 deletions

View File

@ -5,7 +5,7 @@ import os
import os.path as osp
import pkgutil
import re
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, Dict, Optional
@ -33,6 +33,17 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
class _IncompatibleKeys(
namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return '<All keys matched successfully>'
return super().__repr__()
__str__ = __repr__
def _get_mmengine_home():
mmengine_home = os.path.expanduser(
os.getenv(
@ -61,35 +72,53 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
missing_keys = []
err_msg = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
def load(module, local_state_dict, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_model_wrapper(module) or isinstance(module, BaseTTAModel):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
module._load_from_state_dict(local_state_dict, prefix, local_metadata,
True, missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
child_prefix = prefix + name + '.'
child_state_dict = {
k: v
for k, v in local_state_dict.items()
if k.startswith(child_prefix)
}
load(child, child_state_dict, child_prefix)
load(module)
# Note that the hook can modify missing_keys and unexpected_keys.
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
if hasattr(module, '_load_state_dict_post_hooks'):
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible_keys)
assert out is None, (
'Hooks registered with '
'``register_load_state_dict_post_hook`` are not expected '
'to return new values, if incompatible_keys need to be '
'modified, it should be done inplace.')
load(module, state_dict)
load = None # break load->load reference cycle
# ignore "num_batches_tracked" of BN layers
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
key for key in missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:

View File

@ -18,7 +18,7 @@ from mmengine.runner.checkpoint import (CheckpointLoader,
_load_checkpoint_with_prefix,
get_state_dict, load_checkpoint,
load_from_local, load_from_pavi,
save_checkpoint)
load_state_dict, save_checkpoint)
@MODEL_WRAPPERS.register_module()
@ -406,3 +406,30 @@ def test_load_from_local():
assert_tensor_equal(checkpoint['block.conv.weight'],
model.block.conv.weight)
os.remove(checkpoint_path)
def test_load_state_dict_post_hooks():
    """Check that a post-load hook can edit the reported missing keys.

    A state dict without ``norm.running_var`` is loaded twice into a
    ``Block``: first without any hook, where ``print_log`` must be called
    exactly once, then with a hook that removes ``norm.running_var`` from
    ``missing_keys``, where ``print_log`` must not be called at all.
    """
    module = Block()

    # Entries of the checkpoint; ``norm.running_var`` is deliberately absent.
    shapes = {
        'conv.weight': (3, 3, 1, 1),
        'conv.bias': (3, ),
        'norm.weight': [3],
        'norm.bias': [3],
        'norm.running_mean': [3],
    }
    state_dict = {
        key: torch.empty(shape, dtype=torch.float32)
        for key, shape in shapes.items()
    }

    # Without a hook, loading is expected to log exactly once.
    with patch('mmengine.runner.checkpoint.print_log') as log_spy:
        load_state_dict(module, state_dict, strict=False)
        log_spy.assert_called_once()

    def drop_running_var(_, incompatible_keys):
        # Modify the key list in place; the hook itself returns None.
        incompatible_keys.missing_keys.remove('norm.running_var')

    module._load_state_dict_post_hooks = {0: drop_running_var}

    # With the hook installed, nothing is expected to be logged.
    with patch('mmengine.runner.checkpoint.print_log') as log_spy:
        load_state_dict(module, state_dict, strict=False)
        log_spy.assert_not_called()