remove modify of model load

pull/339/head
zouxinyi0625 2024-02-29 17:35:58 +08:00
parent c0333d6f07
commit a7d359f855
1 changed files with 2 additions and 149 deletions

View File

@ -1,16 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import re
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.parallel import is_module_wrapper
from mmcv.runner import _load_checkpoint as _load_checkpoint
from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch import distributed as dist
from torch.optim import Optimizer
from easycv.file import io
@ -51,148 +46,6 @@ def get_checkpoint(filename):
return filename
def load_and_check_state_dict(module: nn.Module,
state_dict: Union[dict, OrderedDict],
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.
This method is modified from :meth:`mmcv.runner.checkpoint.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Raise error when state_dict is highly mismatched.
Args:
module (Module): Module that receives the state_dict.
state_dict (dict or OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys: List[str] = []
all_missing_keys: List[str] = []
err_msg: List[str] = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy() # type: ignore
if metadata is not None:
state_dict._metadata = metadata # type: ignore
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
load(module)
# break load->load reference cycle
load = None # type: ignore
# ignore "num_batches_tracked" of BN layers
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg) # type: ignore
if strict:
raise RuntimeError(err_msg)
else:
if logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
err_msg_list = err_msg.split('\n')
for error_msg_info in err_msg_list:
if 'size mismatch' in error_msg_info and 'cls' not in error_msg_info:
raise RuntimeError(
'Please check your pretrained model. The parameters do not match outside of the cls layer.'
)
def load_and_check_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Union[str, Callable, None] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location, logger)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# strip prefix of state_dict
metadata = getattr(state_dict, '_metadata', OrderedDict())
for p, r in revise_keys:
state_dict = OrderedDict(
{re.sub(p, r, k): v
for k, v in state_dict.items()})
# Keep metadata in state_dict
state_dict._metadata = metadata
# load state_dict
load_and_check_state_dict(model, state_dict, strict, logger)
return checkpoint
def load_checkpoint(model,
filename,
map_location='cpu',
@ -219,7 +72,7 @@ def load_checkpoint(model,
dict or OrderedDict: The loaded checkpoint.
"""
filename = get_checkpoint(filename)
return load_and_check_checkpoint(
return mmcv_load_checkpoint(
model,
filename,
map_location=map_location,