mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add brute-force checkpoint remapping option
This commit is contained in:
parent
b293dfa595
commit
e858912e0c
@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
if remap:
|
||||
state_dict = remap_checkpoint(model, state_dict)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def remap_checkpoint(model, state_dict, allow_reshape=True):
|
||||
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
||||
This assumes models (and originating state dict) were created with params registered in same order.
|
||||
"""
|
||||
out_dict = {}
|
||||
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
|
||||
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
if va.shape != vb.shape:
|
||||
if allow_reshape:
|
||||
vb = vb.reshape(va.shape)
|
||||
else:
|
||||
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
out_dict[ka] = vb
|
||||
return out_dict
|
||||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
|
Loading…
x
Reference in New Issue
Block a user