mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2257 from huggingface/weights_only
Add `weights_only` flag to avoid warning, try to keep backward compatibility.
This commit is contained in:
commit
2df9f2869d
@ -177,12 +177,21 @@ def load_pretrained(
|
|||||||
model.load_pretrained(pretrained_loc)
|
model.load_pretrained(pretrained_loc)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
state_dict = load_state_dict_from_url(
|
try:
|
||||||
pretrained_loc,
|
state_dict = load_state_dict_from_url(
|
||||||
map_location='cpu',
|
pretrained_loc,
|
||||||
progress=_DOWNLOAD_PROGRESS,
|
map_location='cpu',
|
||||||
check_hash=_CHECK_HASH,
|
progress=_DOWNLOAD_PROGRESS,
|
||||||
)
|
check_hash=_CHECK_HASH,
|
||||||
|
weights_only=True,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
state_dict = load_state_dict_from_url(
|
||||||
|
pretrained_loc,
|
||||||
|
map_location='cpu',
|
||||||
|
progress=_DOWNLOAD_PROGRESS,
|
||||||
|
check_hash=_CHECK_HASH,
|
||||||
|
)
|
||||||
elif load_from == 'hf-hub':
|
elif load_from == 'hf-hub':
|
||||||
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
||||||
if isinstance(pretrained_loc, (list, tuple)):
|
if isinstance(pretrained_loc, (list, tuple)):
|
||||||
@ -193,7 +202,7 @@ def load_pretrained(
|
|||||||
else:
|
else:
|
||||||
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
||||||
else:
|
else:
|
||||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
|
||||||
else:
|
else:
|
||||||
model_name = pretrained_cfg.get('architecture', 'this model')
|
model_name = pretrained_cfg.get('architecture', 'this model')
|
||||||
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
|
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
|
||||||
|
@ -44,6 +44,7 @@ def load_state_dict(
|
|||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
use_ema: bool = True,
|
use_ema: bool = True,
|
||||||
device: Union[str, torch.device] = 'cpu',
|
device: Union[str, torch.device] = 'cpu',
|
||||||
|
weights_only: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||||
# Check if safetensors or not and load weights accordingly
|
# Check if safetensors or not and load weights accordingly
|
||||||
@ -51,7 +52,10 @@ def load_state_dict(
|
|||||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
try:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
|
||||||
|
except TypeError:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
|
||||||
state_dict_key = ''
|
state_dict_key = ''
|
||||||
if isinstance(checkpoint, dict):
|
if isinstance(checkpoint, dict):
|
||||||
@ -79,6 +83,7 @@ def load_checkpoint(
|
|||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
remap: bool = False,
|
remap: bool = False,
|
||||||
filter_fn: Optional[Callable] = None,
|
filter_fn: Optional[Callable] = None,
|
||||||
|
weights_only: bool = False,
|
||||||
):
|
):
|
||||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||||
@ -88,7 +93,7 @@ def load_checkpoint(
|
|||||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||||
return
|
return
|
||||||
|
|
||||||
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
|
state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
|
||||||
if remap:
|
if remap:
|
||||||
state_dict = remap_state_dict(state_dict, model)
|
state_dict = remap_state_dict(state_dict, model)
|
||||||
elif filter_fn:
|
elif filter_fn:
|
||||||
@ -126,7 +131,7 @@ def resume_checkpoint(
|
|||||||
):
|
):
|
||||||
resume_epoch = None
|
resume_epoch = None
|
||||||
if os.path.isfile(checkpoint_path):
|
if os.path.isfile(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
|
||||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
if log_info:
|
if log_info:
|
||||||
_logger.info('Restoring model state from checkpoint...')
|
_logger.info('Restoring model state from checkpoint...')
|
||||||
|
@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
|
|||||||
return pretrained_cfg, model_name, model_args
|
return pretrained_cfg, model_name, model_args
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
def load_state_dict_from_hf(
|
||||||
|
model_id: str,
|
||||||
|
filename: str = HF_WEIGHTS_NAME,
|
||||||
|
weights_only: bool = False,
|
||||||
|
):
|
||||||
assert has_hf_hub(True)
|
assert has_hf_hub(True)
|
||||||
hf_model_id, hf_revision = hf_split(model_id)
|
hf_model_id, hf_revision = hf_split(model_id)
|
||||||
|
|
||||||
@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
|||||||
# Otherwise, load using pytorch.load
|
# Otherwise, load using pytorch.load
|
||||||
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
||||||
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||||
return torch.load(cached_file, map_location='cpu')
|
try:
|
||||||
|
state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
|
||||||
|
except TypeError:
|
||||||
|
state_dict = torch.load(cached_file, map_location='cpu')
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
|
def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user