mirror of https://github.com/alibaba/EasyCV.git
remove modify of model load
parent
c0333d6f07
commit
a7d359f855
|
@ -1,16 +1,11 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import is_module_wrapper
|
||||
from mmcv.runner import _load_checkpoint as _load_checkpoint
|
||||
from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
|
||||
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
|
||||
from torch import distributed as dist
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from easycv.file import io
|
||||
|
@ -51,148 +46,6 @@ def get_checkpoint(filename):
|
|||
return filename
|
||||
|
||||
|
||||
def load_and_check_state_dict(module: nn.Module,
                              state_dict: Union[dict, OrderedDict],
                              strict: bool = False,
                              logger: Optional[logging.Logger] = None) -> None:
    """Load state_dict to a module.

    This method is modified from :meth:`mmcv.runner.checkpoint.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.
    Raise error when state_dict is highly mismatched.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (dict or OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: On rank 0, if ``strict`` is True and any mismatch was
            collected, or (regardless of ``strict``) if a collected message
            line contains ``'size mismatch'`` but not ``'cls'``.
    """
    unexpected_keys: List[str] = []
    all_missing_keys: List[str] = []
    err_msg: List[str] = []

    # Work on a shallow copy so the caller's mapping is not touched, and carry
    # over the ``_metadata`` attribute that ``copy()`` does not preserve.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()  # type: ignore
    if metadata is not None:
        state_dict._metadata = metadata  # type: ignore

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        # ``prefix`` ends with '.', while metadata is keyed without it.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    # break load->load reference cycle
    load = None  # type: ignore

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    # Only rank 0 reports; a process outside torch.distributed counts as
    # rank 0.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Nothing to report (or not the reporting rank): stop here so the string
    # handling below never runs on an empty message list.
    if not err_msg or rank != 0:
        return

    # Keep the joined report in its own name instead of rebinding ``err_msg``
    # from a list to a str.
    report = '\n'.join(
        ['The model and loaded state dict do not match exactly\n'] + err_msg)

    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)

    # Even in non-strict mode, a shape mismatch is only tolerated when the
    # message line mentions 'cls'.
    # NOTE(review): this raises on rank 0 only; other ranks return normally
    # -- confirm that is acceptable for distributed runs.
    for line in report.split('\n'):
        if 'size mismatch' in line and 'cls' not in line:
            raise RuntimeError(
                'Please check your pretrained model. The parameters do not match outside of the cls layer.'
            )
|
||||
|
||||
|
||||
def load_and_check_checkpoint(
        model: torch.nn.Module,
        filename: str,
        map_location: Union[str, Callable, None] = None,
        strict: bool = False,
        logger: Optional[logging.Logger] = None,
        revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in checkpoint. Each item is a (pattern, replacement)
            pair of the regular expression operations. Default: strip
            the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    # NOTE: the ``revise_keys`` default is a shared list object; it is only
    # iterated here and must never be mutated.
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint; a checkpoint without a 'state_dict'
    # entry is treated as the weights mapping itself
    state_dict = checkpoint.get('state_dict', checkpoint)

    # Capture metadata before copying; a plain copy does not carry it over.
    metadata = getattr(state_dict, '_metadata', OrderedDict())

    # Always work on an OrderedDict copy: a plain ``dict`` cannot hold the
    # ``_metadata`` attribute assigned below, and the copy keeps the returned
    # checkpoint untouched even when ``revise_keys`` is empty.
    state_dict = OrderedDict(state_dict)

    # strip prefix of state_dict
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items())
    # Keep metadata in state_dict
    state_dict._metadata = metadata

    # load state_dict
    load_and_check_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
||||
|
||||
|
||||
def load_checkpoint(model,
|
||||
filename,
|
||||
map_location='cpu',
|
||||
|
@ -219,7 +72,7 @@ def load_checkpoint(model,
|
|||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
filename = get_checkpoint(filename)
|
||||
return load_and_check_checkpoint(
|
||||
return mmcv_load_checkpoint(
|
||||
model,
|
||||
filename,
|
||||
map_location=map_location,
|
||||
|
|
Loading…
Reference in New Issue