From 0cbf4fa5867b59ea82d0b4e863b86cbd35f5de06 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 19 Jul 2024 11:03:45 -0700
Subject: [PATCH] _orig_mod still causing issues even though I thought it was
 fixed in pytorch, add unwrap / clean helpers

---
 timm/models/_helpers.py | 16 ++++++++++++++--
 timm/utils/model.py     |  7 ++++++-
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py
index 079a4dda..ea7ea290 100644
--- a/timm/models/_helpers.py
+++ b/timm/models/_helpers.py
@@ -19,12 +19,24 @@ _logger = logging.getLogger(__name__)
 __all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
 
 
+def _remove_prefix(text, prefix):
+    # FIXME replace with 3.9 stdlib fn when min at 3.9
+    if text.startswith(prefix):
+        return text[len(prefix):]
+    return text
+
+
 def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
     # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
     cleaned_state_dict = {}
+    to_remove = (
+        'module.',  # DDP wrapper
+        '_orig_mod.',  # torch.compile dynamo wrapper
+    )
     for k, v in state_dict.items():
-        name = k[7:] if k.startswith('module.') else k
-        cleaned_state_dict[name] = v
+        for r in to_remove:
+            k = _remove_prefix(k, r)
+        cleaned_state_dict[k] = v
     return cleaned_state_dict
 
 
diff --git a/timm/utils/model.py b/timm/utils/model.py
index 894453a8..492313cb 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -17,7 +17,12 @@ def unwrap_model(model):
     if isinstance(model, ModelEma):
         return unwrap_model(model.ema)
     else:
-        return model.module if hasattr(model, 'module') else model
+        if hasattr(model, 'module'):
+            return unwrap_model(model.module)
+        elif hasattr(model, '_orig_mod'):
+            return unwrap_model(model._orig_mod)
+        else:
+            return model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
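
A minimal usage sketch (not part of the patch itself) of the behavior the two helpers are meant to cover. It assumes a timm checkout with this change applied and a PyTorch recent enough to provide torch.compile; the layer sizes and state-dict keys below are made-up examples.

import torch
import torch.nn as nn

from timm.models._helpers import clean_state_dict
from timm.utils.model import unwrap_model

# Keys saved from a model wrapped by DistributedDataParallel and then
# torch.compile carry stacked prefixes such as 'module._orig_mod.'.
# clean_state_dict() strips 'module.' first, then '_orig_mod.'.
wrapped = {
    'module._orig_mod.fc.weight': torch.zeros(10, 8),
    'module._orig_mod.fc.bias': torch.zeros(10),
}
print(list(clean_state_dict(wrapped).keys()))
# -> ['fc.weight', 'fc.bias']

# unwrap_model() now recurses through both .module (DDP) and ._orig_mod
# (the dynamo OptimizedModule attribute) to reach the bare nn.Module.
compiled = torch.compile(nn.Linear(8, 10))
assert isinstance(unwrap_model(compiled), nn.Linear)

The recursion in unwrap_model matters because wrappers can stack: a DDP-wrapped compiled model is unwrapped through .module and then ._orig_mod in turn, where the old single hasattr check would have stopped one level short.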