mirror of https://github.com/open-mmlab/mmcv.git
load_checkpoint support normal dict checkpoints (#351)
* load_checkpoint support normal dict checkpoints * commentspull/350/head^2
parent
97f9efd825
commit
f48241a65a
|
@ -222,14 +222,15 @@ def load_checkpoint(model,
|
|||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
checkpoint = _load_checkpoint(filename, map_location)
|
||||
# get state_dict from checkpoint
|
||||
if isinstance(checkpoint, OrderedDict):
|
||||
state_dict = checkpoint
|
||||
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
# OrderedDict is a subclass of dict
|
||||
if not isinstance(checkpoint, dict):
|
||||
raise RuntimeError(
|
||||
f'No state_dict found in checkpoint file {filename}')
|
||||
# get state_dict from checkpoint
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
|
||||
|
|
Loading…
Reference in New Issue