Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #2257 from huggingface/weights_only
Add weights only flag to avoid warning, try to keep bwd compat.
Commit 2df9f2869d
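The pattern used throughout the diff below is a try/except fallback: call the loader with weights_only first, then retry without it when the installed torch is too old to accept the keyword (older torch.load / torch.hub.load_state_dict_from_url raise TypeError on the unknown argument, which is exactly what the except clauses catch). A minimal standalone sketch of that idea; the helper name _torch_load_compat is illustrative and not part of timm:

import torch

def _torch_load_compat(path, map_location='cpu', weights_only=True):
    # Prefer the stricter weights_only load where the installed torch supports it.
    try:
        return torch.load(path, map_location=map_location, weights_only=weights_only)
    except TypeError:
        # Older torch without the weights_only kwarg; fall back to legacy behaviour.
        return torch.load(path, map_location=map_location)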
timm/models/_builder.py

@@ -177,12 +177,21 @@ def load_pretrained(
             model.load_pretrained(pretrained_loc)
             return
         else:
-            state_dict = load_state_dict_from_url(
-                pretrained_loc,
-                map_location='cpu',
-                progress=_DOWNLOAD_PROGRESS,
-                check_hash=_CHECK_HASH,
-            )
+            try:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                    weights_only=True,
+                )
+            except TypeError:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
@@ -193,7 +202,7 @@ def load_pretrained(
             else:
                 state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

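For context on why the flag matters: with weights_only=True, torch.load restricts unpickling to tensors and a small allow-list of types, so a plain state_dict loads normally while a file containing arbitrary pickled Python objects is rejected rather than executed. A quick illustration, assuming a reasonably recent PyTorch (the kwarg exists from roughly 1.13 onward):

import torch

# A tensor-only state_dict round-trips fine under weights_only=True.
torch.save({'weight': torch.zeros(4)}, 'weights.pth')
sd = torch.load('weights.pth', map_location='cpu', weights_only=True)
print(sd['weight'].shape)  # torch.Size([4])

# A checkpoint holding arbitrary pickled objects (e.g. an argparse.Namespace)
# would instead be rejected under weights_only=True (current torch raises an
# UnpicklingError) rather than silently executing the pickle payload.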
timm/models/_helpers.py

@@ -44,6 +44,7 @@ def load_state_dict(
         checkpoint_path: str,
         use_ema: bool = True,
         device: Union[str, torch.device] = 'cpu',
+        weights_only: bool = False,
 ) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
@@ -51,7 +52,10 @@ def load_state_dict(
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            try:
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
+            except TypeError:
+                checkpoint = torch.load(checkpoint_path, map_location=device)

         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -79,6 +83,7 @@ def load_checkpoint(
         strict: bool = True,
         remap: bool = False,
         filter_fn: Optional[Callable] = None,
+        weights_only: bool = False,
 ):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return

-    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
     if remap:
         state_dict = remap_state_dict(state_dict, model)
     elif filter_fn:
@@ -126,7 +131,7 @@ def resume_checkpoint(
 ):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             if log_info:
                 _logger.info('Restoring model state from checkpoint...')

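With these changes callers can opt in per call (the new parameter defaults to weights_only=False), while resume_checkpoint is pinned to weights_only=False, presumably because full training checkpoints may contain objects beyond plain tensors. A hedged usage sketch, assuming the usual timm.models import path for these helpers:

import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet18')  # any architecture name available in timm
# Load a tensor-only checkpoint with the stricter unpickler where supported;
# per the diff above, the helper falls back automatically on older torch versions.
load_checkpoint(model, 'checkpoint.pth', use_ema=True, weights_only=True)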
timm/models/_hub.py

@@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name, model_args


-def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
+def load_state_dict_from_hf(
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)

@@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
-    return torch.load(cached_file, map_location='cpu')
+    try:
+        state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        state_dict = torch.load(cached_file, map_location='cpu')
+    return state_dict


 def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
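End to end, the ordinary pretrained-model entry point picks up the safer path automatically once the hub and URL loaders pass weights_only=True, with the commit message noting the intent to keep backward compatibility for existing callers. For example:

import timm

# Downloading pretrained weights (from the HF hub or a URL) now goes through
# the weights_only=True branch when the installed torch supports it.
model = timm.create_model('resnet18', pretrained=True)
model.eval()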