_orig_mod still causing issues even though I thought it was fixed in pytorch, add unwrap / clean helpers

Ross Wightman 2024-07-19 11:03:45 -07:00
parent 3a8a965891
commit 0cbf4fa586
2 changed files with 20 additions and 3 deletions
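The issue being addressed, in a minimal sketch (assumes PyTorch 2.x with torch.compile available; exact key handling can vary by version): a compiled module keeps the original module under _orig_mod, so its state_dict keys pick up an '_orig_mod.' prefix and no longer match the keys of the plain model.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
compiled = torch.compile(model)             # OptimizedModule wrapping `model`
print(list(model.state_dict().keys()))      # ['weight', 'bias']
print(list(compiled.state_dict().keys()))   # e.g. ['_orig_mod.weight', '_orig_mod.bias']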


@@ -19,12 +19,24 @@ _logger = logging.getLogger(__name__)
 __all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
 
 
+def _remove_prefix(text, prefix):
+    # FIXME replace with 3.9 stdlib fn when min at 3.9
+    if text.startswith(prefix):
+        return text[len(prefix):]
+    return text
+
+
 def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
     # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
     cleaned_state_dict = {}
+    to_remove = (
+        'module.',  # DDP wrapper
+        '_orig_mod.',  # torchcompile dynamo wrapper
+    )
     for k, v in state_dict.items():
-        name = k[7:] if k.startswith('module.') else k
-        cleaned_state_dict[name] = v
+        for r in to_remove:
+            k = _remove_prefix(k, r)
+        cleaned_state_dict[k] = v
     return cleaned_state_dict
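A quick usage sketch of the new key handling (the import path is an assumption; clean_state_dict is the helper in the hunk above, and the keys/shapes here are made up for illustration):

import torch
from timm.models import clean_state_dict  # assumed export of the helper above

state_dict = {
    'module.stem.conv.weight': torch.zeros(32, 3, 3, 3),   # DDP 'module.' prefix
    '_orig_mod.head.fc.weight': torch.zeros(1000, 512),    # torch.compile '_orig_mod.' prefix
    'head.fc.bias': torch.zeros(1000),                      # already clean, passed through
}
print(list(clean_state_dict(state_dict).keys()))
# ['stem.conv.weight', 'head.fc.weight', 'head.fc.bias']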


@@ -17,7 +17,12 @@ def unwrap_model(model):
     if isinstance(model, ModelEma):
         return unwrap_model(model.ema)
     else:
-        return model.module if hasattr(model, 'module') else model
+        if hasattr(model, 'module'):
+            return unwrap_model(model.module)
+        elif hasattr(model, '_orig_mod'):
+            return unwrap_model(model._orig_mod)
+        else:
+            return model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
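A sketch of what the recursive unwrap buys (assumes PyTorch 2.x; DataParallel stands in for any wrapper exposing .module, and the import path is an assumption): each level of wrapping is peeled off in turn, so a model that is both compiled and parallel-wrapped resolves to the same inner module before its state dict is read.

import torch
import torch.nn as nn
from timm.utils import unwrap_model  # assumed export of the helper above

base = nn.Linear(8, 4)
compiled = torch.compile(base)             # exposes ._orig_mod
wrapped = nn.DataParallel(compiled)        # exposes .module

assert unwrap_model(compiled) is base
assert unwrap_model(wrapped) is base       # recursion peels both wrappers

With the previous single-level check, unwrap_model(wrapped) stopped at the compiled module, so saved checkpoints could still carry the '_orig_mod.' prefix.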