diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 7741cf94..8d124afe 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -177,12 +177,21 @@ def load_pretrained( model.load_pretrained(pretrained_loc) return else: - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) + try: + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + weights_only=True, + ) + except TypeError: + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') if isinstance(pretrained_loc, (list, tuple)): @@ -193,7 +202,7 @@ def load_pretrained( else: state_dict = load_state_dict_from_hf(*pretrained_loc) else: - state_dict = load_state_dict_from_hf(pretrained_loc) + state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True) else: model_name = pretrained_cfg.get('architecture', 'this model') raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.") diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index ea7ea290..3622b104 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -44,6 +44,7 @@ def load_state_dict( checkpoint_path: str, use_ema: bool = True, device: Union[str, torch.device] = 'cpu', + weights_only: bool = False, ) -> Dict[str, Any]: if checkpoint_path and os.path.isfile(checkpoint_path): # Check if safetensors or not and load weights accordingly @@ -51,7 +52,10 @@ def load_state_dict( assert _has_safetensors, "`pip install safetensors` to use .safetensors" checkpoint = safetensors.torch.load_file(checkpoint_path, device=device) else: - checkpoint = torch.load(checkpoint_path, map_location=device) + try: + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only) + except TypeError: + checkpoint = torch.load(checkpoint_path, map_location=device) state_dict_key = '' if isinstance(checkpoint, dict): @@ -79,6 +83,7 @@ def load_checkpoint( strict: bool = True, remap: bool = False, filter_fn: Optional[Callable] = None, + weights_only: bool = False, ): if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn @@ -88,7 +93,7 @@ def load_checkpoint( raise NotImplementedError('Model cannot load numpy checkpoint') return - state_dict = load_state_dict(checkpoint_path, use_ema, device=device) + state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only) if remap: state_dict = remap_state_dict(state_dict, model) elif filter_fn: @@ -126,7 +131,7 @@ def resume_checkpoint( ): resume_epoch = None if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if log_info: _logger.info('Restoring model state from checkpoint...') diff --git a/timm/models/_hub.py b/timm/models/_hub.py index e3eafc48..19b3dcf0 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str): return pretrained_cfg, model_name, model_args -def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): +def load_state_dict_from_hf( + model_id: str, + filename: str = HF_WEIGHTS_NAME, + weights_only: bool = False, +): assert has_hf_hub(True) hf_model_id, hf_revision = hf_split(model_id) @@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): # Otherwise, load using pytorch.load cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.") - return torch.load(cached_file, map_location='cpu') + try: + state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only) + except TypeError: + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):