mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
_orig_mod still causing issues even though I thought it was fixed in pytorch, add unwrap / clean helpers
This commit is contained in:
parent
3a8a965891
commit
0cbf4fa586
@ -19,12 +19,24 @@ _logger = logging.getLogger(__name__)
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
|
||||
|
||||
|
||||
def _remove_prefix(text, prefix):
|
||||
# FIXME replace with 3.9 stdlib fn when min at 3.9
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix):]
|
||||
return text
|
||||
|
||||
|
||||
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||
cleaned_state_dict = {}
|
||||
to_remove = (
|
||||
'module.', # DDP wrapper
|
||||
'_orig_mod.', # torchcompile dynamo wrapper
|
||||
)
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k
|
||||
cleaned_state_dict[name] = v
|
||||
for r in to_remove:
|
||||
k = _remove_prefix(k, r)
|
||||
cleaned_state_dict[k] = v
|
||||
return cleaned_state_dict
|
||||
|
||||
|
||||
|
@ -17,7 +17,12 @@ def unwrap_model(model):
|
||||
if isinstance(model, ModelEma):
|
||||
return unwrap_model(model.ema)
|
||||
else:
|
||||
return model.module if hasattr(model, 'module') else model
|
||||
if hasattr(model, 'module'):
|
||||
return unwrap_model(model.module)
|
||||
elif hasattr(model, '_orig_mod'):
|
||||
return unwrap_model(model._orig_mod)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def get_state_dict(model, unwrap_fn=unwrap_model):
|
||||
|
Loading…
x
Reference in New Issue
Block a user