Merge pull request #2257 from huggingface/weights_only

Add weights only flag to avoid warning, try to keep bwd compat.
Ross Wightman 2024-08-12 18:09:36 -07:00 committed by GitHub
commit 2df9f2869d
3 changed files with 34 additions and 12 deletions

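Every hunk in this commit applies the same pattern: request weights_only from torch.load (directly, or via load_state_dict_from_url), and if the installed PyTorch predates that argument (it appeared around torch 1.13), catch the resulting TypeError and retry with the legacy call. A minimal standalone sketch of the pattern; the helper name and checkpoint path here are illustrative, not part of the commit:

    import torch

    def _load_with_fallback(path, map_location='cpu', weights_only=True):
        # Newer torch: restrict unpickling to tensors and plain containers.
        try:
            return torch.load(path, map_location=map_location, weights_only=weights_only)
        except TypeError:
            # Older torch does not accept the kwarg; retry with the legacy signature.
            return torch.load(path, map_location=map_location)

    # e.g. state_dict = _load_with_fallback('checkpoint.pth')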

@@ -177,12 +177,21 @@ def load_pretrained(
             model.load_pretrained(pretrained_loc)
             return
         else:
-            state_dict = load_state_dict_from_url(
-                pretrained_loc,
-                map_location='cpu',
-                progress=_DOWNLOAD_PROGRESS,
-                check_hash=_CHECK_HASH,
-            )
+            try:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                    weights_only=True,
+                )
+            except TypeError:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
@@ -193,7 +202,7 @@ def load_pretrained(
             else:
                 state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

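The two hunks above live in the pretrained-weight loader, load_pretrained (most likely timm/models/_builder.py): URL downloads now request weights_only=True with a TypeError fallback, and the Hugging Face Hub branch passes weights_only=True through to load_state_dict_from_hf. Callers see no API change; an ordinary pretrained model build exercises the new path. Illustrative usage (the model name is only an example):

    import timm

    # create_model(..., pretrained=True) routes through load_pretrained(),
    # which now prefers the safer weights_only load for downloaded checkpoints.
    model = timm.create_model('resnet50', pretrained=True)
    model.eval()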

@@ -44,6 +44,7 @@ def load_state_dict(
         checkpoint_path: str,
         use_ema: bool = True,
         device: Union[str, torch.device] = 'cpu',
+        weights_only: bool = False,
 ) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
@@ -51,7 +52,10 @@ def load_state_dict(
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            try:
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
+            except TypeError:
+                checkpoint = torch.load(checkpoint_path, map_location=device)
 
         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -79,6 +83,7 @@ def load_checkpoint(
         strict: bool = True,
         remap: bool = False,
         filter_fn: Optional[Callable] = None,
+        weights_only: bool = False,
 ):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
 
-    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
     if remap:
         state_dict = remap_state_dict(state_dict, model)
     elif filter_fn:
@@ -126,7 +131,7 @@ def resume_checkpoint(
 ):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             if log_info:
                 _logger.info('Restoring model state from checkpoint...')

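In the checkpoint helpers above, load_state_dict and load_checkpoint (most likely timm/models/_helpers.py), the flag is an opt-in parameter defaulting to False, so loading local checkpoints behaves exactly as before unless the caller asks for the restricted unpickler; resume_checkpoint pins weights_only=False because full training checkpoints can hold arbitrary pickled objects (optimizer state, argparse args, scaler state) that a weights-only load may reject. A hedged usage sketch; the import path and checkpoint path are assumptions for illustration:

    import timm
    from timm.models import load_checkpoint

    model = timm.create_model('resnet50', pretrained=False)
    # Opt in to the restricted unpickler for a local, tensors-only checkpoint;
    # on older torch the helper falls back to the legacy torch.load call.
    load_checkpoint(model, 'output/train/model_best.pth.tar', use_ema=True, weights_only=True)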

@@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name, model_args
 
 
-def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
+def load_state_dict_from_hf(
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
 
@@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
-    return torch.load(cached_file, map_location='cpu')
+    try:
+        state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        state_dict = torch.load(cached_file, map_location='cpu')
+    return state_dict
 
 
 def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
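The Hub loader, load_state_dict_from_hf (most likely timm/models/_hub.py), gains the same optional flag for the case where no .safetensors alternative exists and it must fall back to a pickled .bin file. A hedged usage sketch; the import path and model id are assumptions for illustration:

    from timm.models import load_state_dict_from_hf

    # Fetches the checkpoint from the HF Hub (a safetensors file is preferred when present);
    # weights_only=True only affects the torch.load fallback path.
    state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k', weights_only=True)
    print(f'{len(state_dict)} tensors loaded')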