Add weights_only flag to avoid the torch.load warning, try to keep backwards compat. Default to True for remote loads of pretrained weights, keep False for local checkpoint loads to avoid breaking training checkpoints. Fix #2249

Ross Wightman 2024-08-12 11:41:41 -07:00
parent 531215eded
commit bd0f79153b
3 changed files with 34 additions and 12 deletions
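
In usage terms, the intent looks like this (a minimal sketch against the public timm API; the checkpoint path is illustrative):

import timm
from timm.models import load_checkpoint

# Remote pretrained weights: the loader now tries torch.load(..., weights_only=True)
# first and falls back to a full load if that raises ValueError.
model = timm.create_model('resnet50', pretrained=True)

# Local training checkpoints: weights_only defaults to False so checkpoints that
# pickle optimizer/scheduler state keep loading as before; callers opt in explicitly.
load_checkpoint(model, 'output/train/checkpoint.pth.tar', weights_only=True)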


@@ -177,12 +177,21 @@ def load_pretrained(
             model.load_pretrained(pretrained_loc)
             return
         else:
-            state_dict = load_state_dict_from_url(
-                pretrained_loc,
-                map_location='cpu',
-                progress=_DOWNLOAD_PROGRESS,
-                check_hash=_CHECK_HASH,
-            )
+            try:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                    weights_only=True,
+                )
+            except ValueError:
+                state_dict = load_state_dict_from_url(
+                    pretrained_loc,
+                    map_location='cpu',
+                    progress=_DOWNLOAD_PROGRESS,
+                    check_hash=_CHECK_HASH,
+                )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
@@ -193,7 +202,7 @@ def load_pretrained(
             else:
                 state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")


@@ -44,6 +44,7 @@ def load_state_dict(
         checkpoint_path: str,
         use_ema: bool = True,
         device: Union[str, torch.device] = 'cpu',
+        weights_only: bool = False,
 ) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
@@ -51,7 +52,10 @@
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            try:
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
+            except ValueError:
+                checkpoint = torch.load(checkpoint_path, map_location=device)

         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -79,6 +83,7 @@ def load_checkpoint(
         strict: bool = True,
         remap: bool = False,
         filter_fn: Optional[Callable] = None,
+        weights_only: bool = False,
 ):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
@@ -88,7 +93,7 @@ def load_checkpoint(
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return

-    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
     if remap:
         state_dict = remap_state_dict(state_dict, model)
     elif filter_fn:
@@ -126,7 +131,7 @@ def resume_checkpoint(
 ):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            if log_info:
                _logger.info('Restoring model state from checkpoint...')
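
For context, a toy check of what the restricted load accepts (not part of this diff; assumes torch >= 1.13, where the argument exists):

import torch

# Save a checkpoint containing only tensors and containers, then load it with the
# restricted unpickler; arbitrary pickled objects would be rejected here.
torch.save({'state_dict': {'w': torch.zeros(2, 2)}}, 'toy_ckpt.pth')
checkpoint = torch.load('toy_ckpt.pth', map_location='cpu', weights_only=True)
assert 'state_dict' in checkpoint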


@@ -168,7 +168,11 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name, model_args


-def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
+def load_state_dict_from_hf(
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -187,7 +191,11 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
-    return torch.load(cached_file, map_location='cpu')
+    try:
+        state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
+    except ValueError:
+        state_dict = torch.load(cached_file, map_location='cpu')
+    return state_dict


 def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
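
The updated hub loader can also be called directly (a sketch; the model id is an assumed example):

from timm.models._hub import load_state_dict_from_hf

# weights_only=True requests the restricted load for non-safetensors files;
# safetensors checkpoints on the hub are preferred and already safe to load.
state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k', weights_only=True)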