diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py
index 079a4dda..ea7ea290 100644
--- a/timm/models/_helpers.py
+++ b/timm/models/_helpers.py
@@ -19,12 +19,24 @@ _logger = logging.getLogger(__name__)
 __all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
 
 
+def _remove_prefix(text, prefix):
+    # FIXME replace with 3.9 stdlib fn when min at 3.9
+    if text.startswith(prefix):
+        return text[len(prefix):]
+    return text
+
+
 def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
     # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
     cleaned_state_dict = {}
+    to_remove = (
+        'module.',  # DDP wrapper
+        '_orig_mod.',  # torchcompile dynamo wrapper
+    )
     for k, v in state_dict.items():
-        name = k[7:] if k.startswith('module.') else k
-        cleaned_state_dict[name] = v
+        for r in to_remove:
+            k = _remove_prefix(k, r)
+        cleaned_state_dict[k] = v
     return cleaned_state_dict
 
 
diff --git a/timm/utils/model.py b/timm/utils/model.py
index 894453a8..492313cb 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -17,7 +17,12 @@ def unwrap_model(model):
     if isinstance(model, ModelEma):
         return unwrap_model(model.ema)
     else:
-        return model.module if hasattr(model, 'module') else model
+        if hasattr(model, 'module'):
+            return unwrap_model(model.module)
+        elif hasattr(model, '_orig_mod'):
+            return unwrap_model(model._orig_mod)
+        else:
+            return model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
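
For context, a minimal usage sketch of what this change handles, not part of the diff itself: wrapping a model with torch.compile produces an OptimizedModule that stores the original model under `_orig_mod`, so saved state dict keys pick up a `_orig_mod.` prefix, much like the `module.` prefix from DDP. The updated `clean_state_dict` strips both so such checkpoints load into a plain model. The snippet assumes PyTorch 2.x behavior at the time of this change and imports `clean_state_dict` from the module edited above.

# Illustrative sketch (assumption: PyTorch 2.x '_orig_mod.' key prefixing).
import torch
import torch.nn as nn

from timm.models._helpers import clean_state_dict

model = nn.Linear(4, 2)
compiled = torch.compile(model)    # wraps model in an OptimizedModule

sd = compiled.state_dict()
print(list(sd.keys()))             # e.g. ['_orig_mod.weight', '_orig_mod.bias']

cleaned = clean_state_dict(sd)
print(list(cleaned.keys()))        # ['weight', 'bias']

model.load_state_dict(cleaned)     # loads into the uncompiled model without key mismatches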