mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add weights only flag to avoid warning, try to keep bwd compat. Default to True for remote load of pretrained weights, keep False for local checkpoing load to avoid training checkpoint breaks.. fix #2249
This commit is contained in:
parent
531215eded
commit
bd0f79153b
@ -177,6 +177,15 @@ def load_pretrained(
|
||||
model.load_pretrained(pretrained_loc)
|
||||
return
|
||||
else:
|
||||
try:
|
||||
state_dict = load_state_dict_from_url(
|
||||
pretrained_loc,
|
||||
map_location='cpu',
|
||||
progress=_DOWNLOAD_PROGRESS,
|
||||
check_hash=_CHECK_HASH,
|
||||
weights_only=True,
|
||||
)
|
||||
except ValueError:
|
||||
state_dict = load_state_dict_from_url(
|
||||
pretrained_loc,
|
||||
map_location='cpu',
|
||||
@ -193,7 +202,7 @@ def load_pretrained(
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
|
||||
else:
|
||||
model_name = pretrained_cfg.get('architecture', 'this model')
|
||||
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
|
||||
|
@ -44,6 +44,7 @@ def load_state_dict(
|
||||
checkpoint_path: str,
|
||||
use_ema: bool = True,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
weights_only: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
@ -51,6 +52,9 @@ def load_state_dict(
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
||||
else:
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
|
||||
except ValueError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
state_dict_key = ''
|
||||
@ -79,6 +83,7 @@ def load_checkpoint(
|
||||
strict: bool = True,
|
||||
remap: bool = False,
|
||||
filter_fn: Optional[Callable] = None,
|
||||
weights_only: bool = False,
|
||||
):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
@ -88,7 +93,7 @@ def load_checkpoint(
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
|
||||
if remap:
|
||||
state_dict = remap_state_dict(state_dict, model)
|
||||
elif filter_fn:
|
||||
@ -126,7 +131,7 @@ def resume_checkpoint(
|
||||
):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring model state from checkpoint...')
|
||||
|
@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
|
||||
return pretrained_cfg, model_name, model_args
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
||||
def load_state_dict_from_hf(
|
||||
model_id: str,
|
||||
filename: str = HF_WEIGHTS_NAME,
|
||||
weights_only: bool = False,
|
||||
):
|
||||
assert has_hf_hub(True)
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
|
||||
@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
||||
# Otherwise, load using pytorch.load
|
||||
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
||||
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
return torch.load(cached_file, map_location='cpu')
|
||||
try:
|
||||
state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
|
||||
except ValueError:
|
||||
state_dict = torch.load(cached_file, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user