mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
`_orig_mod` is still causing issues even though I thought it was fixed in PyTorch; add unwrap / clean helpers
This commit is contained in:
parent
3a8a965891
commit
0cbf4fa586
@ -19,12 +19,24 @@ _logger = logging.getLogger(__name__)
|
|||||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
|
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` removed, if present.

    Only a prefix at the very start of ``text`` is removed, and only once;
    if ``text`` does not start with ``prefix`` it is returned unchanged.

    Args:
        text: String to strip the prefix from.
        prefix: Prefix to remove.

    Returns:
        ``text`` without the leading ``prefix``, or ``text`` itself.
    """
    # FIXME replace with 3.9 stdlib fn when min at 3.9
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
|
||||||
|
|
||||||
|
|
||||||
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``state_dict`` with wrapper prefixes stripped from its keys.

    'Cleans' a checkpoint by removing the leading ``module.`` (DDP wrapper) and
    ``_orig_mod.`` (torchcompile dynamo wrapper) prefixes from every key. The
    prefixes are stripped repeatedly, so nested wrappers are handled in either
    order (``module._orig_mod.`` as well as ``_orig_mod.module.``). Only leading
    prefixes are touched; occurrences in the middle of a key are preserved.

    Args:
        state_dict: Mapping of parameter/buffer names to values.

    Returns:
        A new dict with the same values under cleaned keys. The input mapping
        is not modified.
    """
    to_remove = (
        'module.',  # DDP wrapper
        '_orig_mod.',  # torchcompile dynamo wrapper
    )
    cleaned_state_dict = {}
    for k, v in state_dict.items():
        # A single fixed-order pass would leave 'module.' behind for keys such as
        # '_orig_mod.module.x', so keep stripping until no known prefix leads the key.
        # str.startswith accepts a tuple; each pass shortens k, so this terminates.
        while k.startswith(to_remove):
            for r in to_remove:
                if k.startswith(r):
                    k = k[len(r):]
        cleaned_state_dict[k] = v
    return cleaned_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,12 @@ def unwrap_model(model):
|
|||||||
if isinstance(model, ModelEma):
|
if isinstance(model, ModelEma):
|
||||||
return unwrap_model(model.ema)
|
return unwrap_model(model.ema)
|
||||||
else:
|
else:
|
||||||
return model.module if hasattr(model, 'module') else model
|
if hasattr(model, 'module'):
|
||||||
|
return unwrap_model(model.module)
|
||||||
|
elif hasattr(model, '_orig_mod'):
|
||||||
|
return unwrap_model(model._orig_mod)
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict(model, unwrap_fn=unwrap_model):
|
def get_state_dict(model, unwrap_fn=unwrap_model):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user