From bd0f79153b863e911618bbb05e2e8ae8e10e8277 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 12 Aug 2024 11:41:41 -0700 Subject: [PATCH] Add weights only flag to avoid warning, try to keep bwd compat. Default to True for remote load of pretrained weights, keep False for local checkpoing load to avoid training checkpoint breaks.. fix #2249 --- timm/models/_builder.py | 23 ++++++++++++++++------- timm/models/_helpers.py | 11 ++++++++--- timm/models/_hub.py | 12 ++++++++++-- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 7741cf94..7bfdff41 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -177,12 +177,21 @@ def load_pretrained( model.load_pretrained(pretrained_loc) return else: - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) + try: + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + weights_only=True, + ) + except ValueError: + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') if isinstance(pretrained_loc, (list, tuple)): @@ -193,7 +202,7 @@ def load_pretrained( else: state_dict = load_state_dict_from_hf(*pretrained_loc) else: - state_dict = load_state_dict_from_hf(pretrained_loc) + state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True) else: model_name = pretrained_cfg.get('architecture', 'this model') raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.") diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index ea7ea290..7f384182 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -44,6 +44,7 @@ def load_state_dict( checkpoint_path: str, use_ema: bool = True, device: Union[str, torch.device] = 'cpu', + weights_only: bool = False, ) -> Dict[str, Any]: if checkpoint_path and os.path.isfile(checkpoint_path): # Check if safetensors or not and load weights accordingly @@ -51,7 +52,10 @@ def load_state_dict( assert _has_safetensors, "`pip install safetensors` to use .safetensors" checkpoint = safetensors.torch.load_file(checkpoint_path, device=device) else: - checkpoint = torch.load(checkpoint_path, map_location=device) + try: + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only) + except ValueError: + checkpoint = torch.load(checkpoint_path, map_location=device) state_dict_key = '' if isinstance(checkpoint, dict): @@ -79,6 +83,7 @@ def load_checkpoint( strict: bool = True, remap: bool = False, filter_fn: Optional[Callable] = None, + weights_only: bool = False, ): if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn @@ -88,7 +93,7 @@ def load_checkpoint( raise NotImplementedError('Model cannot load numpy checkpoint') return - state_dict = load_state_dict(checkpoint_path, use_ema, device=device) + state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only) if remap: state_dict = remap_state_dict(state_dict, model) elif filter_fn: @@ -126,7 +131,7 @@ def resume_checkpoint( ): resume_epoch = None if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if log_info: _logger.info('Restoring model state from checkpoint...') diff --git a/timm/models/_hub.py b/timm/models/_hub.py index e3eafc48..bcbf93e4 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str): return pretrained_cfg, model_name, model_args -def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): +def load_state_dict_from_hf( + model_id: str, + filename: str = HF_WEIGHTS_NAME, + weights_only: bool = False, +): assert has_hf_hub(True) hf_model_id, hf_revision = hf_split(model_id) @@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): # Otherwise, load using pytorch.load cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.") - return torch.load(cached_file, map_location='cpu') + try: + state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only) + except ValueError: + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):